Преглед изворни кода

HPCC-21222 Datasets passed to R dataframes should use Vectors rather than lists

Signed-off-by: Richard Chapman <rchapman@hpccsystems.com>
Richard Chapman пре 6 година
родитељ
комит
18622e2144
3 измењених фајлова са 108 додато и 47 уклоњено
  1. 6 6
      cmake_modules/FindR.cmake
  2. 101 37
      plugins/Rembed/Rembed.cpp
  3. 1 4
      testing/regress/ecl/modelingWithR.ecl

+ 6 - 6
cmake_modules/FindR.cmake

@@ -32,13 +32,13 @@ IF (NOT R_FOUND)
     SET (RInside_lib "RInside")
   ENDIF()
 
-  FIND_PATH(R_INCLUDE_DIR R.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/include)
-  FIND_PATH(RCPP_INCLUDE_DIR Rcpp.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/library/Rcpp/include/ R/site-library/Rcpp/include/)
-  FIND_PATH(RINSIDE_INCLUDE_DIR RInside.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/library/RInside/include  R/site-library/RInside/include)
+  FIND_PATH(R_INCLUDE_DIR R.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/include Resources/include)
+  FIND_PATH(RCPP_INCLUDE_DIR Rcpp.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/library/Rcpp/include/ R/site-library/Rcpp/include/ Resources/library/Rcpp/include/)
+  FIND_PATH(RINSIDE_INCLUDE_DIR RInside.h PATHS /usr/lib /usr/lib64 /usr/share /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/library/RInside/include  R/site-library/RInside/include /Resources/library/RInside/include)
 
-  FIND_LIBRARY (R_LIBRARY NAMES ${R_lib}  PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/lib)
-  FIND_LIBRARY (RCPP_LIBRARY NAMES ${Rcpp_lib} PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/library/Rcpp/lib/ R/site-library/Rcpp/lib/)
-  FIND_LIBRARY (RINSIDE_LIBRARY NAMES ${RInside_lib} PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 PATH_SUFFIXES R/library/RInside/lib/ R/site-library/RInside/lib/)
+  FIND_LIBRARY (R_LIBRARY NAMES ${R_lib}  PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/lib Resources/lib)
+  FIND_LIBRARY (RCPP_LIBRARY NAMES ${Rcpp_lib} PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/library/Rcpp/lib/ R/site-library/Rcpp/lib/)
+  FIND_LIBRARY (RINSIDE_LIBRARY NAMES ${RInside_lib} PATHS /usr/lib /usr/share /usr/lib64 /usr/local/lib /usr/local/lib64 /Library/Frameworks/R.framework/Versions/3.5/ PATH_SUFFIXES R/library/RInside/lib/ R/site-library/RInside/lib/ Resources/library/RInside/lib/)
 
   IF (RCPP_LIBRARY STREQUAL "RCPP_LIBRARY-NOTFOUND")
     SET (RCPP_LIBRARY "")    # Newer versions of Rcpp are header-only, with no associated library.

+ 101 - 37
plugins/Rembed/Rembed.cpp

@@ -262,10 +262,47 @@ static Rcpp::DataFrame createDataFrame(const RtlTypeInfo *typeInfo, unsigned num
     getFieldNames(namevec, typeInfo);
     Rcpp::List frame(namevec.length()); // Number of columns
     frame.attr("names") = namevec;
+    const RtlFieldInfo * const *fields = typeInfo->queryFields();
     for (int i=0; i< frame.length(); i++)
     {
-        Rcpp::List column(numRows);
-        frame[i] = column;
+        assertex(*fields);
+        const RtlFieldInfo *child = *fields;
+        switch (child->type->getType())
+        {
+        case type_boolean:
+        {
+            Rcpp::LogicalVector column(numRows);
+            frame[i] = column;
+            break;
+        }
+        case type_int:
+        {
+            Rcpp::IntegerVector column(numRows);
+            frame[i] = column;
+            break;
+        }
+        case type_real:
+        case type_decimal:
+        {
+            Rcpp::NumericVector column(numRows);
+            frame[i] = column;
+            break;
+        }
+        case type_string:
+        case type_varstring:
+        {
+            Rcpp::StringVector column(numRows);
+            frame[i] = column;
+            break;
+        }
+        default:
+        {
+            Rcpp::List column(numRows);
+            frame[i] = column;
+            break;
+        }
+        }
+        fields++;
     }
     Rcpp::StringVector row_names(numRows);
     for (unsigned row = 0; row < numRows; row++)
@@ -310,22 +347,14 @@ public:
         if (inSet)
             theStringSet[setIndex++] = s;
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = s;
-        }
+            stack.tos().setString(s);
     }
     virtual void processBool(bool value, const RtlFieldInfo * field) override
     {
         if (inSet)
             theBoolSet[setIndex++] = value;
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = value;
-        }
+            stack.tos().setBool(value);
     }
     virtual void processData(unsigned len, const void *value, const RtlFieldInfo * field) override
     {
@@ -341,33 +370,21 @@ public:
         if (inSet)
             theIntSet[setIndex++] = (long int) value;
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = (long int) value;  // Rcpp does not support int64
-        }
+            stack.tos().setInt(value);
     }
     virtual void processUInt(unsigned __int64 value, const RtlFieldInfo * field) override
     {
         if (inSet)
             theIntSet[setIndex++] = (unsigned long int) value;
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = (unsigned long int) value;  // Rcpp does not support int64
-        }
+            stack.tos().setUInt(value);
     }
     virtual void processReal(double value, const RtlFieldInfo * field) override
     {
         if (inSet)
             theRealSet[setIndex++] = value;
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = value;
-        }
+            stack.tos().setReal(value);
     }
     virtual void processDecimal(const void *value, unsigned digits, unsigned precision, const RtlFieldInfo * field) override
     {
@@ -376,11 +393,7 @@ public:
         if (inSet)
             theRealSet[setIndex++] = val.getReal();
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = val.getReal();
-        }
+            stack.tos().setReal(val.getReal());
     }
     virtual void processUDecimal(const void *value, unsigned digits, unsigned precision, const RtlFieldInfo * field) override
     {
@@ -389,11 +402,7 @@ public:
         if (inSet)
             theRealSet[setIndex++] = val.getReal();
         else
-        {
-            unsigned r;
-            Rcpp::List l = stack.tos().cell(r);
-            l[r] = val.getReal();
-        }
+            stack.tos().setReal(val.getReal());
     }
     virtual void processUnicode(unsigned len, const UChar *value, const RtlFieldInfo * field) override
     {
@@ -490,6 +499,11 @@ protected:
     interface IDataListPosition : public IInterface
     {
         virtual Rcpp::List cell(unsigned &row) = 0;
+        virtual void setBool(bool value) = 0;
+        virtual void setInt(__int64 value) = 0;
+        virtual void setUInt(unsigned __int64 value) = 0;
+        virtual void setReal(double value) = 0;
+        virtual void setString(const std::string &s) = 0;
         virtual void nextRow() = 0;
         virtual bool isNestedRow(const RtlFieldInfo *_field) const = 0;
     };
@@ -503,6 +517,36 @@ protected:
             curCell = frame[colIdx++];
             return curCell;
         }
+        virtual void setBool(bool value) override
+        {
+            unsigned row = rowIdx-1;        // nextRow is called before the first row, so rowIdx is 1-based
+            Rcpp::LogicalVector l = frame[colIdx++];
+            l[row] = value;
+        }
+        virtual void setInt(__int64 value) override
+        {
+            unsigned row = rowIdx-1;
+            Rcpp::IntegerVector l = frame[colIdx++];
+            l[row] = (long int) value;  // Rcpp does not support int64
+        }
+        virtual void setUInt(unsigned __int64 value) override
+        {
+            unsigned row = rowIdx-1;
+            Rcpp::IntegerVector l = frame[colIdx++];
+            l[row] = (unsigned long int) value;  // Rcpp does not support int64
+        }
+        virtual void setReal(double value) override
+        {
+            unsigned row = rowIdx-1;
+            Rcpp::NumericVector l = frame[colIdx++];
+            l[row] = value;
+        }
+        virtual void setString(const std::string &value) override
+        {
+            unsigned row = rowIdx-1;
+            Rcpp::StringVector l = frame[colIdx++];
+            l[row] = value;
+        }
         virtual void nextRow() override
         {
             rowIdx++;
@@ -529,6 +573,26 @@ protected:
             row = colIdx++;
             return list;
         }
+        virtual void setBool(bool value) override
+        {
+            list[colIdx++] = value;
+        }
+        virtual void setInt(__int64 value) override
+        {
+            list[colIdx++] = (long int) value;  // Rcpp does not support int64
+        }
+        virtual void setUInt(unsigned __int64 value) override
+        {
+            list[colIdx++] = (unsigned long int) value;  // Rcpp does not support int64
+        }
+        virtual void setReal(double value) override
+        {
+            list[colIdx++] = value;
+        }
+        virtual void setString(const std::string &value) override
+        {
+            list[colIdx++] = value;
+        }
         virtual void nextRow() override
         {
             colIdx = 0;

+ 1 - 4
testing/regress/ecl/modelingWithR.ecl

@@ -60,7 +60,7 @@ irisFitRec := RECORD
   STRING100 fit;
 END;
 
-// We return a record containing predicitons and fit
+// We return a record containing predictions and fit
 resultRec := RECORD
   DATASET(irisPredictedRec) predictions;
   DATASET(irisFitRec) fit;
@@ -68,9 +68,6 @@ END;
 
 // Main embedded R script for building and scoring models
 resultRec runAnalyses(DATASET(irisRec) ds) := EMBED(R)
-
-  ds <- data.frame(lapply(ds, unlist))
-  
   # Build models
   lmMdl <- lm(label ~ ., data = ds) # same as glm with Gaussian family
   glmMdl <- glm(label ~ ., family = binomial, data = ds)