소스 검색

HPCC-14623 Allow SUM(GROUP)/y inside a TABLE activity

Signed-off-by: Gavin Halliday <gavin.halliday@lexisnexis.com>
Gavin Halliday 9 년 전
부모
커밋
66551fde28
3개의 변경된 파일212개의 추가작업 그리고 185개의 파일을 삭제
  1. 200 185
      ecl/hqlcpp/hqlttcpp.cpp
  2. 1 0
      testing/regress/ecl/aggds1.ecl
  3. 11 0
      testing/regress/ecl/key/aggds1.xml

+ 200 - 185
ecl/hqlcpp/hqlttcpp.cpp

@@ -1505,207 +1505,220 @@ static IHqlExpression * createArith(node_operator op, ITypeInfo * type, IHqlExpr
 }
 
 
-
-//MORE: This might be better as a class, it would reduce the number of parameters
-IHqlExpression * doNormalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, HqlExprArray & fields, HqlExprArray & assigns, bool & extraSelectNeeded, bool canOptimizeCasts);
-IHqlExpression * evalNormalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, HqlExprArray & fields, HqlExprArray & assigns, bool & extraSelectNeeded, bool canOptimizeCasts)
+class AggregateNormalizer
 {
-    switch (expr->getOperator())
+public:
+    AggregateNormalizer(IHqlExpression * _rootSelector, HqlExprArray & _fields, HqlExprArray & _assigns, bool & _extraSelectNeeded)
+       : rootSelector(_rootSelector), fields(_fields), assigns(_assigns), extraSelectNeeded(_extraSelectNeeded)
     {
-    case no_avegroup:
-        //Map this to sum(x)/count(x)
-        {
-            IHqlExpression * arg = expr->queryChild(0);
-            IHqlExpression * cond = expr->queryChild(1);
-
-            Owned<ITypeInfo> sumType = getSumAggType(arg);
-            ITypeInfo * exprType = expr->queryType();
-            OwnedHqlExpr sum = createValue(no_sumgroup, LINK(sumType), LINK(arg), LINK(cond));
-            OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
-
-            //average should be done as a real operation I think, possibly decimal, if argument is decimal
-            OwnedHqlExpr avg = createArith(no_div, exprType, sum, count);
-            return doNormalizeAggregateExpr(selector, avg, fields, assigns, extraSelectNeeded, false);
-        }
-    case no_vargroup:
-        //Map this to (sum(x^2)-sum(x)^2/count())/count()
-        {
-            IHqlExpression * arg = expr->queryChild(0);
-            IHqlExpression * cond = expr->queryChild(1);
-
-            ITypeInfo * exprType = expr->queryType();
-            OwnedHqlExpr xx = createArith(no_mul, exprType, arg, arg);
-            OwnedHqlExpr sumxx = createValue(no_sumgroup, LINK(exprType), LINK(xx), LINK(cond));
-            OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(arg), LINK(cond));
-            OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
-
-            //average should be done as a real operation I think, possibly decimal, if argument is decimal
-            OwnedHqlExpr n1 = createArith(no_mul, exprType, sumx, sumx);
-            OwnedHqlExpr n2 = createArith(no_div, exprType, n1, count);
-            OwnedHqlExpr n3 = createArith(no_sub, exprType, sumxx, n2);
-            OwnedHqlExpr n4 = createArith(no_div, exprType, n3, count);
-            return doNormalizeAggregateExpr(selector, n4, fields, assigns, extraSelectNeeded, false);
-        }
-    case no_covargroup:
-        //Map this to (sum(x.y)-sum(x).sum(y)/count())/count()
-        {
-            IHqlExpression * argX = expr->queryChild(0);
-            IHqlExpression * argY = expr->queryChild(1);
-            IHqlExpression * cond = expr->queryChild(2);
-
-            ITypeInfo * exprType = expr->queryType();
-            OwnedHqlExpr xy = createArith(no_mul, exprType, argX, argY);
-            OwnedHqlExpr sumxy = createValue(no_sumgroup, LINK(exprType), LINK(xy), LINK(cond));
-            OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(argX), LINK(cond));
-            OwnedHqlExpr sumy = createValue(no_sumgroup, LINK(exprType), LINK(argY), LINK(cond));
-            OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
-
-            //average should be done as a real operation I think, possibly decimal, if argument is decimal
-            OwnedHqlExpr n1 = createArith(no_mul, exprType, sumx, sumy);
-            OwnedHqlExpr n2 = createArith(no_div, exprType, n1, count);
-            OwnedHqlExpr n3 = createArith(no_sub, exprType, sumxy, n2);
-            OwnedHqlExpr n4 = createArith(no_div, exprType, n3, count);
-            return doNormalizeAggregateExpr(selector, n4, fields, assigns, extraSelectNeeded, false);
-        }
-    case no_corrgroup:
-        //Map this to (covar(x,y)/(var(x).var(y)))
-        //== (sum(x.y)*count() - sum(x).sum(y))/sqrt((sum(x.x)*count()-sum(x)^2) * (sum(y.y)*count()-sum(y)^2))
-        {
-            IHqlExpression * argX = expr->queryChild(0);
-            IHqlExpression * argY = expr->queryChild(1);
-            IHqlExpression * cond = expr->queryChild(2);
-
-            ITypeInfo * exprType = expr->queryType();
-            OwnedHqlExpr xx = createArith(no_mul, exprType, argX, argX);
-            OwnedHqlExpr sumxx = createValue(no_sumgroup, LINK(exprType), LINK(xx), LINK(cond));
-            OwnedHqlExpr xy = createArith(no_mul, exprType, argX, argY);
-            OwnedHqlExpr sumxy = createValue(no_sumgroup, LINK(exprType), LINK(xy), LINK(cond));
-            OwnedHqlExpr yy = createArith(no_mul, exprType, argY, argY);
-            OwnedHqlExpr sumyy = createValue(no_sumgroup, LINK(exprType), LINK(yy), LINK(cond));
-            OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(argX), LINK(cond));
-            OwnedHqlExpr sumy = createValue(no_sumgroup, LINK(exprType), LINK(argY), LINK(cond));
-            OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
-
-            OwnedHqlExpr n1 = createArith(no_mul, exprType, sumxy, count);
-            OwnedHqlExpr n2 = createArith(no_mul, exprType, sumx, sumy);
-            OwnedHqlExpr n3 = createArith(no_sub, exprType, n1, n2);
-
-            OwnedHqlExpr n4 = createArith(no_mul, exprType, sumxx, count);
-            OwnedHqlExpr n5 = createArith(no_mul, exprType, sumx, sumx);
-            OwnedHqlExpr n6 = createArith(no_sub, exprType, n4, n5);
-
-            OwnedHqlExpr n7 = createArith(no_mul, exprType, sumyy, count);
-            OwnedHqlExpr n8 = createArith(no_mul, exprType, sumy, sumy);
-            OwnedHqlExpr n9 = createArith(no_sub, exprType, n7, n8);
-
-            OwnedHqlExpr n10 = createArith(no_mul, exprType, n6, n9);
-            OwnedHqlExpr n11 = createValue(no_sqrt, LINK(exprType), LINK(n10));
-            OwnedHqlExpr n12 = createArith(no_div, exprType, n3, n11);
-            return doNormalizeAggregateExpr(selector, n12, fields, assigns, extraSelectNeeded, false);
-        }
-        throwUnexpected();
+    }
+    IHqlExpression * normalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, bool canOptimizeCasts)
+    {
+        IHqlExpression * match = static_cast<IHqlExpression *>(expr->queryTransformExtra());
+        if (match)
+            return LINK(match);
+        IHqlExpression * ret = evalNormalizeAggregateExpr(selector, expr, canOptimizeCasts);
+        expr->setTransformExtra(ret);
+        return ret;
+    }
 
-    case no_variance:
-    case no_covariance:
-    case no_correlation:
-        throwUnexpectedOp(expr->getOperator());
 
-    case no_count:
-    case no_sum:
-    case no_max:
-    case no_min:
-    case no_ave:
-    case no_select:
-    case no_exists:
-    case no_field:
-        // a count on a child dataset or something else - add it as it is...
-        //goes wrong for count(group)*
-        return LINK(expr);
-    case no_countgroup:
-    case no_sumgroup:
-    case no_maxgroup:
-    case no_mingroup:
-    case no_existsgroup:
+protected:
+    IHqlExpression * evalNormalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, bool canOptimizeCasts)
+    {
+        if (!expr->isGroupAggregateFunction() && !expr->usesSelector(rootSelector))
+            return LINK(expr);
+
+        switch (expr->getOperator())
         {
-            ForEachItemIn(idx, assigns)
+        case no_avegroup:
+            //Map this to sum(x)/count(x)
+            {
+                IHqlExpression * arg = expr->queryChild(0);
+                IHqlExpression * cond = expr->queryChild(1);
+
+                Owned<ITypeInfo> sumType = getSumAggType(arg);
+                ITypeInfo * exprType = expr->queryType();
+                OwnedHqlExpr sum = createValue(no_sumgroup, LINK(sumType), LINK(arg), LINK(cond));
+                OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
+
+                //average should be done as a real operation I think, possibly decimal, if argument is decimal
+                OwnedHqlExpr avg = createArith(no_div, exprType, sum, count);
+                return normalizeAggregateExpr(selector, avg, false);
+            }
+        case no_vargroup:
+            //Map this to (sum(x^2)-sum(x)^2/count())/count()
             {
-                IHqlExpression & cur = assigns.item(idx);
-                if (cur.queryChild(1) == expr)
+                IHqlExpression * arg = expr->queryChild(0);
+                IHqlExpression * cond = expr->queryChild(1);
+
+                ITypeInfo * exprType = expr->queryType();
+                OwnedHqlExpr xx = createArith(no_mul, exprType, arg, arg);
+                OwnedHqlExpr sumxx = createValue(no_sumgroup, LINK(exprType), LINK(xx), LINK(cond));
+                OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(arg), LINK(cond));
+                OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
+
+                //average should be done as a real operation I think, possibly decimal, if argument is decimal
+                OwnedHqlExpr n1 = createArith(no_mul, exprType, sumx, sumx);
+                OwnedHqlExpr n2 = createArith(no_div, exprType, n1, count);
+                OwnedHqlExpr n3 = createArith(no_sub, exprType, sumxx, n2);
+                OwnedHqlExpr n4 = createArith(no_div, exprType, n3, count);
+                return normalizeAggregateExpr(selector, n4, false);
+            }
+        case no_covargroup:
+            //Map this to (sum(x.y)-sum(x).sum(y)/count())/count()
+            {
+                IHqlExpression * argX = expr->queryChild(0);
+                IHqlExpression * argY = expr->queryChild(1);
+                IHqlExpression * cond = expr->queryChild(2);
+
+                ITypeInfo * exprType = expr->queryType();
+                OwnedHqlExpr xy = createArith(no_mul, exprType, argX, argY);
+                OwnedHqlExpr sumxy = createValue(no_sumgroup, LINK(exprType), LINK(xy), LINK(cond));
+                OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(argX), LINK(cond));
+                OwnedHqlExpr sumy = createValue(no_sumgroup, LINK(exprType), LINK(argY), LINK(cond));
+                OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
+
+                //average should be done as a real operation I think, possibly decimal, if argument is decimal
+                OwnedHqlExpr n1 = createArith(no_mul, exprType, sumx, sumy);
+                OwnedHqlExpr n2 = createArith(no_div, exprType, n1, count);
+                OwnedHqlExpr n3 = createArith(no_sub, exprType, sumxy, n2);
+                OwnedHqlExpr n4 = createArith(no_div, exprType, n3, count);
+                return normalizeAggregateExpr(selector, n4, false);
+            }
+        case no_corrgroup:
+            //Map this to (covar(x,y)/(var(x).var(y)))
+            //== (sum(x.y)*count() - sum(x).sum(y))/sqrt((sum(x.x)*count()-sum(x)^2) * (sum(y.y)*count()-sum(y)^2))
+            {
+                IHqlExpression * argX = expr->queryChild(0);
+                IHqlExpression * argY = expr->queryChild(1);
+                IHqlExpression * cond = expr->queryChild(2);
+
+                ITypeInfo * exprType = expr->queryType();
+                OwnedHqlExpr xx = createArith(no_mul, exprType, argX, argX);
+                OwnedHqlExpr sumxx = createValue(no_sumgroup, LINK(exprType), LINK(xx), LINK(cond));
+                OwnedHqlExpr xy = createArith(no_mul, exprType, argX, argY);
+                OwnedHqlExpr sumxy = createValue(no_sumgroup, LINK(exprType), LINK(xy), LINK(cond));
+                OwnedHqlExpr yy = createArith(no_mul, exprType, argY, argY);
+                OwnedHqlExpr sumyy = createValue(no_sumgroup, LINK(exprType), LINK(yy), LINK(cond));
+                OwnedHqlExpr sumx = createValue(no_sumgroup, LINK(exprType), LINK(argX), LINK(cond));
+                OwnedHqlExpr sumy = createValue(no_sumgroup, LINK(exprType), LINK(argY), LINK(cond));
+                OwnedHqlExpr count = createValue(no_countgroup, LINK(defaultIntegralType), LINK(cond));
+
+                OwnedHqlExpr n1 = createArith(no_mul, exprType, sumxy, count);
+                OwnedHqlExpr n2 = createArith(no_mul, exprType, sumx, sumy);
+                OwnedHqlExpr n3 = createArith(no_sub, exprType, n1, n2);
+
+                OwnedHqlExpr n4 = createArith(no_mul, exprType, sumxx, count);
+                OwnedHqlExpr n5 = createArith(no_mul, exprType, sumx, sumx);
+                OwnedHqlExpr n6 = createArith(no_sub, exprType, n4, n5);
+
+                OwnedHqlExpr n7 = createArith(no_mul, exprType, sumyy, count);
+                OwnedHqlExpr n8 = createArith(no_mul, exprType, sumy, sumy);
+                OwnedHqlExpr n9 = createArith(no_sub, exprType, n7, n8);
+
+                OwnedHqlExpr n10 = createArith(no_mul, exprType, n6, n9);
+                OwnedHqlExpr n11 = createValue(no_sqrt, LINK(exprType), LINK(n10));
+                OwnedHqlExpr n12 = createArith(no_div, exprType, n3, n11);
+                return normalizeAggregateExpr(selector, n12, false);
+            }
+            throwUnexpected();
+
+        case no_variance:
+        case no_covariance:
+        case no_correlation:
+            throwUnexpectedOp(expr->getOperator());
+
+        case no_select:
+        case no_count:
+        case no_sum:
+        case no_max:
+        case no_min:
+        case no_ave:
+        case no_exists:
+            //Fall through - and add the expression as a field in the resulting aggregate which is later projected
+        case no_countgroup:
+        case no_sumgroup:
+        case no_maxgroup:
+        case no_mingroup:
+        case no_existsgroup:
+            {
+                ForEachItemIn(idx, assigns)
+                {
+                    IHqlExpression & cur = assigns.item(idx);
+                    if (cur.queryChild(1) == expr)
+                    {
+                        extraSelectNeeded = true;
+                        return LINK(cur.queryChild(0)); //replaceSelector(cur.queryChild(0), querySelf(), queryActiveTableSelector());
+                    }
+                }
+                IHqlExpression * targetField;
+                if (selector)
                 {
+                    targetField = LINK(selector->queryChild(1));
+                }
+                else
+                {
+                    StringBuffer temp;
+                    temp.append("_agg_").append(assigns.ordinality());
+                    targetField = createField(createIdAtom(temp.str()), expr->getType(), NULL);
                     extraSelectNeeded = true;
-                    return LINK(cur.queryChild(0)); //replaceSelector(cur.queryChild(0), querySelf(), queryActiveTableSelector());
                 }
+                fields.append(*targetField);
+                assigns.append(*createAssign(createSelectExpr(getActiveTableSelector(), LINK(targetField)), LINK(expr)));
+                return createSelectExpr(getActiveTableSelector(), LINK(targetField));
             }
-            IHqlExpression * targetField;
-            if (selector)
-            {
-                targetField = LINK(selector->queryChild(1));
-            }
-            else
+        case no_cast:
+        case no_implicitcast:
+            if (selector && canOptimizeCasts)
             {
-                StringBuffer temp;
-                temp.append("_agg_").append(assigns.ordinality());
-                targetField = createField(createIdAtom(temp.str()), expr->getType(), NULL);
-                extraSelectNeeded = true;
+                IHqlExpression * child = expr->queryChild(0);
+                if (expr->queryType()->getTypeCode() == child->queryType()->getTypeCode())
+                {
+                    IHqlExpression * ret = normalizeAggregateExpr(selector, child, false);
+                    //This should be ret==child
+                    if (ret == selector)
+                        return ret;
+                    HqlExprArray args;
+                    args.append(*ret);
+                    return expr->clone(args);
+                }
             }
-            fields.append(*targetField);
-            assigns.append(*createAssign(createSelectExpr(getActiveTableSelector(), LINK(targetField)), LINK(expr)));
-            return createSelectExpr(getActiveTableSelector(), LINK(targetField));
-        }
-    case no_cast:
-    case no_implicitcast:
-        if (selector && canOptimizeCasts)
-        {
-            IHqlExpression * child = expr->queryChild(0);
-            if (expr->queryType()->getTypeCode() == child->queryType()->getTypeCode())
+            //fallthrough...
+        default:
             {
-                IHqlExpression * ret = doNormalizeAggregateExpr(selector, child, fields, assigns, extraSelectNeeded, false);
-                //This should be ret==child
-                if (ret == selector)
-                    return ret;
                 HqlExprArray args;
-                args.append(*ret);
-                return expr->clone(args);
+                unsigned max = expr->numChildren();
+                unsigned idx;
+                bool diff = false;
+                args.ensure(max);
+                for (idx = 0; idx < max; idx++)
+                {
+                    IHqlExpression * child = expr->queryChild(idx);
+                    IHqlExpression * changed = normalizeAggregateExpr(NULL, child, false);
+                    args.append(*changed);
+                    if (child != changed)
+                        diff = true;
+                }
+                if (diff)
+                    return expr->clone(args);
+                return LINK(expr);
             }
         }
-        //fallthrough...
-    default:
-        {
-            HqlExprArray args;
-            unsigned max = expr->numChildren();
-            unsigned idx;
-            bool diff = false;
-            args.ensure(max);
-            for (idx = 0; idx < max; idx++)
-            {
-                IHqlExpression * child = expr->queryChild(idx);
-                IHqlExpression * changed = doNormalizeAggregateExpr(NULL, child, fields, assigns, extraSelectNeeded, false);
-                args.append(*changed);
-                if (child != changed)
-                    diff = true;
-            }
-            if (diff)
-                return expr->clone(args);
-            return LINK(expr);
-        }
     }
-}
-
-IHqlExpression * doNormalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, HqlExprArray & fields, HqlExprArray & assigns, bool & extraSelectNeeded, bool canOptimizeCasts)
-{
-    IHqlExpression * match = static_cast<IHqlExpression *>(expr->queryTransformExtra());
-    if (match)
-        return LINK(match);
-    IHqlExpression * ret = evalNormalizeAggregateExpr(selector, expr, fields, assigns, extraSelectNeeded, canOptimizeCasts);
-    expr->setTransformExtra(ret);
-    return ret;
-}
 
+protected:
+    TransformMutexBlock block;
+    IHqlExpression * rootSelector;
+    HqlExprArray & fields;
+    HqlExprArray & assigns;
+    bool & extraSelectNeeded;
+};
 
-IHqlExpression * normalizeAggregateExpr(IHqlExpression * selector, IHqlExpression * expr, HqlExprArray & fields, HqlExprArray & assigns, bool & extraSelectNeeded, bool canOptimizeCasts)
+IHqlExpression * normalizeAggregateExpr(IHqlExpression * tableSelector, IHqlExpression * selector, IHqlExpression * expr, HqlExprArray & fields, HqlExprArray & assigns, bool & extraSelectNeeded, bool canOptimizeCasts)
 {
-    TransformMutexBlock block;
-    return doNormalizeAggregateExpr(selector, expr, fields, assigns, extraSelectNeeded, canOptimizeCasts);
+    AggregateNormalizer normalizer(tableSelector, fields, assigns, extraSelectNeeded);
+    return normalizer.normalizeAggregateExpr(selector, expr, canOptimizeCasts);
 }
 
 
@@ -3545,11 +3558,13 @@ IHqlExpression * ThorHqlTransformer::normalizeTableToAggregate(IHqlExpression *
         IHqlExpression * assign=transform->queryChild(idx);
         IHqlExpression * cur = assign->queryChild(0);
         IHqlExpression * src = assign->queryChild(1);
-        IHqlExpression * mapped = normalizeAggregateExpr(cur, src, aggregateFields, aggregateAssigns, extraSelectNeeded, canOptimizeCasts);
-
-        if (mapped == src)
+        IHqlExpression * mapped;
+        if (src->isGroupAggregateFunction())
+        {
+            mapped = normalizeAggregateExpr(dataset->queryNormalizedSelector(), cur, src, aggregateFields, aggregateAssigns, extraSelectNeeded, canOptimizeCasts);
+        }
+        else
         {
-            mapped->Release();
             mapped = replaceSelector(cur, self, queryActiveTableSelector());
 
             // Not an aggregate - must be an expression that is used in the grouping

+ 1 - 0
testing/regress/ecl/aggds1.ecl

@@ -47,4 +47,5 @@ sq := setup.sq(multiPart);
     output(exists(sq.SimplePersonBookDs));
     output(exists(sq.SimplePersonBookDs(forename = 'Gavin')));
     output(exists(sq.SimplePersonBookDs(forename = 'Joshua')));
+    output(sort(table(sq.SimplePersonBookDs, { aage, scale := count(group, (forename != 'Gavin')) / aage }, aage), aage));
 );

+ 11 - 0
testing/regress/ecl/key/aggds1.xml

@@ -22,3 +22,14 @@
 <Dataset name='Result 8'>
  <Row><Result_8>false</Result_8></Row>
 </Dataset>
+<Dataset name='Result 9'>
+ <Row><aage>5</aage><scale>0.2</scale></Row>
+ <Row><aage>34</aage><scale>0.02941176470588235</scale></Row>
+ <Row><aage>35</aage><scale>0.0</scale></Row>
+ <Row><aage>49</aage><scale>0.02040816326530612</scale></Row>
+ <Row><aage>50</aage><scale>0.02</scale></Row>
+ <Row><aage>68</aage><scale>0.01470588235294118</scale></Row>
+ <Row><aage>78</aage><scale>0.01282051282051282</scale></Row>
+ <Row><aage>83</aage><scale>0.01204819277108434</scale></Row>
+ <Row><aage>99</aage><scale>0.0202020202020202</scale></Row>
+</Dataset>