Sfoglia il codice sorgente

Merge pull request #9953 from mckellyln/ssl_listener

HPCC-17260 Roxie SSL listener for batch queries

Reviewed-By: Anthony Fishbeck <anthony.fishbeck@lexisnexis.com>
Reviewed-By: Richard Chapman <rchapman@hpccsystems.com>
Richard Chapman 8 anni fa
parent
commit
397b6fb5da

+ 1 - 0
roxie/ccd/CMakeLists.txt

@@ -88,6 +88,7 @@ include_directories (
          ${CMAKE_BINARY_DIR}/oss
          ${HPCC_SOURCE_DIR}/dali/ft
          ${HPCC_SOURCE_DIR}/system/security/shared
+         ${HPCC_SOURCE_DIR}/system/security/securesocket
     )
 
 ADD_DEFINITIONS( -D_USRDLL -DCCD_EXPORTS -DSTARTQUERY_EXPORTS )

+ 38 - 1
roxie/ccd/ccdmain.cpp

@@ -1043,6 +1043,8 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
         if (!localSlave)
             openMulticastSocket();
 
+        StringBuffer certFileName;
+        StringBuffer keyFileName;
         setDaliServixSocketCaching(true);  // enable daliservix caching
         loadPlugins();
         createDelayedReleaser();
@@ -1118,6 +1120,7 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
                 unsigned numThreads = roxieFarm.getPropInt("@numThreads", numServerThreads);
                 unsigned port = roxieFarm.getPropInt("@port", ROXIE_SERVER_PORT);
                 unsigned requestArrayThreads = roxieFarm.getPropInt("@requestArrayThreads", 5);
+                // NOTE: farmer name [@name=] is not copied into topology
                 const IpAddress &ip = getNodeAddress(myNodeIndex);
                 if (!roxiePort)
                 {
@@ -1129,10 +1132,44 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
                 if (port)
                 {
                     const char *protocol = roxieFarm.queryProp("@protocol");
+                    const char *passPhrase = nullptr;
+                    const char *certFile = nullptr;
+                    const char *keyFile = nullptr;
+                    if (protocol && streq(protocol, "ssl"))
+                    {
+#ifdef _USE_OPENSSL
+                        certFile = roxieFarm.queryProp("@certificateFileName");
+                        if (!certFile)
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing certificateFileName tag", port);
+                        if (isAbsolutePath(certFile))
+                            certFileName.append(certFile);
+                        else
+                            certFileName.append(codeDirectory.str()).append(certFile);
+                        if (!checkFileExists(certFileName.str()))
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing certificateFile (%s)", port, certFileName.str());
+
+                        keyFile =  roxieFarm.queryProp("@privateKeyFileName");
+                        if (!keyFile)
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing privateKeyFileName tag", port);
+                        if (isAbsolutePath(keyFile))
+                            keyFileName.append(keyFile);
+                        else
+                            keyFileName.append(codeDirectory.str()).append(keyFile);
+                        if (!checkFileExists(keyFileName.str()))
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing privateKeyFile (%s)", port, keyFileName.str());
+
+                        passPhrase = roxieFarm.queryProp("@passphrase");
+                        if (isEmptyString(passPhrase))
+                            passPhrase = nullptr;
+#else
+                        WARNLOG("Skipping Roxie SSL Farm Listener on port %d : OpenSSL disabled in build", port);
+                        continue;
+#endif
+                    }
                     const char *soname =  roxieFarm.queryProp("@so");
                     const char *config  = roxieFarm.queryProp("@config");
                     Owned<IHpccProtocolPlugin> protocolPlugin = ensureProtocolPlugin(*protocolCtx, soname);
-                    roxieServer.setown(protocolPlugin->createListener(protocol ? protocol : "native", createRoxieProtocolMsgSink(ip, port, numThreads, suspended), port, listenQueue, config));
+                    roxieServer.setown(protocolPlugin->createListener(protocol ? protocol : "native", createRoxieProtocolMsgSink(ip, port, numThreads, suspended), port, listenQueue, config, certFileName.str(), keyFileName.str(), passPhrase));
                 }
                 else
                     roxieServer.setown(createRoxieWorkUnitListener(numThreads, suspended));

+ 55 - 8
roxie/ccd/ccdprotocol.cpp

@@ -22,10 +22,11 @@
 #include "roxie.hpp"
 #include "roxiehelper.hpp"
 #include "ccdprotocol.hpp"
+#include "securesocket.hpp"
 
 //================================================================================================================================
 
-IHpccProtocolListener *createProtocolListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue);
+IHpccProtocolListener *createProtocolListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *certFile, const char *keyFile, const char *passPhrase);
 
 class CHpccProtocolPlugin : implements IHpccProtocolPlugin, public CInterface
 {
@@ -54,9 +55,9 @@ public:
         trapTooManyActiveQueries = ctx.ctxGetPropBool("@trapTooManyActiveQueries", true);
         numRequestArrayThreads = ctx.ctxGetPropInt("@requestArrayThreads", 5);
     }
-    IHpccProtocolListener *createListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *config)
+    IHpccProtocolListener *createListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *config, const char *certFile=nullptr, const char *keyFile=nullptr, const char *passPhrase=nullptr)
     {
-        return createProtocolListener(protocol, sink, port, listenQueue);
+        return createProtocolListener(protocol, sink, port, listenQueue, certFile, keyFile, passPhrase);
     }
 public:
     StringArray targetNames;
@@ -218,14 +219,23 @@ class ProtocolSocketListener : public ProtocolListener
     unsigned listenQueue;
     Owned<ISocket> socket;
     SocketEndpoint ep;
+    const char *protocol;
+    const char *certFile;
+    const char *keyFile;
+    const char *passPhrase;
+    Owned<ISecureSocketContext> secureContext;
 
 public:
-    ProtocolSocketListener(IHpccProtocolMsgSink *_sink, unsigned _port, unsigned _listenQueue)
+    ProtocolSocketListener(IHpccProtocolMsgSink *_sink, unsigned _port, unsigned _listenQueue, const char *_protocol, const char *_certFile, const char *_keyFile, const char *_passPhrase)
       : ProtocolListener(_sink)
     {
         port = _port;
         listenQueue = _listenQueue;
         ep.set(port, queryHostIP());
+        protocol = _protocol;
+        certFile = _certFile;
+        keyFile = _keyFile;
+        passPhrase = _passPhrase;
     }
 
     IHpccProtocolMsgSink *queryMsgSink()
@@ -269,11 +279,48 @@ public:
         started.signal();
         while (running)
         {
-            ISocket *client = socket->accept(true);
+            Owned<ISocket> client = socket->accept(true);
+            Owned<ISecureSocket> ssock;
             if (client)
             {
+                if (protocol && streq(protocol, "ssl"))
+                {
+#ifdef _USE_OPENSSL
+                    try
+                    {
+                        if (!secureContext)
+                            secureContext.setown(createSecureSocketContextEx(certFile, keyFile, passPhrase, ServerSocket));
+                        ssock.setown(secureContext->createSecureSocket(client.getClear()));
+                        int status = ssock->secure_accept();
+                        if (status < 0)
+                        {
+                            // secure_accept may also DBGLOG() errors ...
+                            WARNLOG("ProtocolSocketListener failure to establish secure connection");
+                            continue;
+                        }
+                    }
+                    catch (IException *E)
+                    {
+                        StringBuffer s;
+                        E->errorMessage(s);
+                        WARNLOG("%s", s.str());
+                        E->Release();
+                        continue;
+                    }
+                    catch (...)
+                    {
+                        StringBuffer s;
+                        WARNLOG("ProtocolSocketListener failure to establish secure connection");
+                        continue;
+                    }
+                    client.setown(ssock.getClear());
+#else
+                    WARNLOG("ProtocolSocketListener failure to establish secure connection: OpenSSL disabled in build");
+                    continue;
+#endif
+                }
                 client->set_linger(-1);
-                pool->start(client);
+                pool->start(client.getClear());
             }
         }
         DBGLOG("ProtocolSocketListener closed query socket");
@@ -2022,11 +2069,11 @@ void ProtocolSocketListener::runOnce(const char *query)
     p->runOnce(query);
 }
 
-IHpccProtocolListener *createProtocolListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue)
+IHpccProtocolListener *createProtocolListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *certFile=nullptr, const char *keyFile=nullptr, const char *passPhrase=nullptr)
 {
     if (traceLevel)
         DBGLOG("Creating Roxie socket listener, protocol %s, pool size %d, listen queue %d%s", protocol, sink->getPoolSize(), listenQueue, sink->getIsSuspended() ? " SUSPENDED":"");
-    return new ProtocolSocketListener(sink, port, listenQueue);
+    return new ProtocolSocketListener(sink, port, listenQueue, protocol, certFile, keyFile, passPhrase);
 }
 
 extern IHpccProtocolPlugin *loadHpccProtocolPlugin(IHpccProtocolPluginContext *ctx, IActiveQueryLimiterFactory *_limiterFactory)

+ 1 - 1
roxie/ccd/hpccprotocol.hpp

@@ -136,7 +136,7 @@ interface IActiveQueryLimiterFactory : extends IInterface
 
 interface IHpccProtocolPlugin : extends IInterface
 {
-    virtual IHpccProtocolListener *createListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *config)=0;
+    virtual IHpccProtocolListener *createListener(const char *protocol, IHpccProtocolMsgSink *sink, unsigned port, unsigned listenQueue, const char *config, const char *certFile=nullptr, const char *keyFile=nullptr, const char *passPhrase=nullptr)=0;
 };
 
 extern IHpccProtocolPlugin *loadHpccProtocolPlugin(IHpccProtocolPluginContext *ctx, IActiveQueryLimiterFactory *limiterFactory);

+ 28 - 25
system/security/securesocket/securesocket.cpp

@@ -164,12 +164,15 @@ public:
     virtual void   read(void* buf, size32_t size)
     {
         size32_t size_read;
-        readTimeout(buf, size, size, size_read, 0, false);
+        // MCK - this was:
+        // readTimeout(buf, size, size, size_read, 0, false);
+        // but that is essentially a non-blocking read() and we want a blocking read() ...
+        readTimeout(buf, 0, size, size_read, WAIT_FOREVER, false);
     }
 
     virtual size32_t get_max_send_size()
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::get_max_send_size: not implemented");
     }
 
     //
@@ -177,7 +180,7 @@ public:
     //
     virtual ISocket* accept(bool allowcancel=false) // not needed for UDP
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::accept: not implemented");
     }
 
     //
@@ -185,7 +188,7 @@ public:
     // 
     virtual int wait_write(unsigned timeout)
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::wait_write: not implemented");
     }
 
     //
@@ -194,14 +197,14 @@ public:
     //
     virtual bool set_nonblock(bool on) // returns old state
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_nonblock: not implemented");
     }
 
     // enable 'nagling' - small packet coalescing (implies delayed transmission)
     //
     virtual bool set_nagle(bool on) // returns old state
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_nagle: not implemented");
     }
 
 
@@ -209,7 +212,7 @@ public:
     //
     virtual void set_linger(int lingersecs)  
     {
-        throw MakeStringException(-1, "not implemented");
+        m_socket->set_linger(lingersecs);
     }
 
 
@@ -218,7 +221,7 @@ public:
     //
     virtual void  cancel_accept() // not needed for UDP
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::cancel_accept: not implemented");
     }
 
     //
@@ -258,12 +261,12 @@ public:
     //
     virtual bool connectionless() // true if accept need not be called (i.e. UDP)
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::connectionless: not implemented");
     }
 
     virtual void set_return_addr(int port,const char *name) // used for UDP servers only
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_return_addr: not implemented");
     }
 
     // Block functions 
@@ -274,7 +277,7 @@ public:
                             unsigned timeout=0 // timeout in msecs (0 for no timeout)
                   ) 
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_block_mode: not implemented");
     }
 
 
@@ -284,12 +287,12 @@ public:
                             size32_t sz          // size to send (0 for eof)
                   )
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::send_block: not implemented");
     }
 
     virtual size32_t receive_block_size ()     // get size of next block (always must call receive_block after) 
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::receive_block_size: not implemented");
     }
 
     virtual size32_t receive_block(
@@ -298,7 +301,7 @@ public:
                                                // if less than block size truncates block
                   )
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::receive_block: not implemented");
     }
 
     virtual void  close()
@@ -321,59 +324,59 @@ public:
 
     virtual size32_t write_multiple(unsigned num,const void **buf, size32_t *size)
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::write_multiple: not implemented");
     }
 
     virtual size32_t get_send_buffer_size() // get OS send buffer
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::get_send_buffer_size: not implemented");
     }
 
     void set_send_buffer_size(size32_t sz)  // set OS send buffer size
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_send_buffer_size: not implemented");
     }
 
     bool join_multicast_group(SocketEndpoint &ep)   // for udp multicast
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::join_multicast_group: not implemented");
         return false;
     }
 
     bool leave_multicast_group(SocketEndpoint &ep)  // for udp multicast
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::leave_multicast_group: not implemented");
         return false;
     }
 
     void set_ttl(unsigned _ttl)   // set ttl
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_ttl: not implemented");
     }
 
     size32_t get_receive_buffer_size()  // get OS send buffer
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::get_receive_buffer_size: not implemented");
     }
 
     void set_receive_buffer_size(size32_t sz)   // set OS send buffer size
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_receive_buffer_size: not implemented");
     }
 
     virtual void set_keep_alive(bool set) // set option SO_KEEPALIVE
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::set_keep_alive: not implemented");
     }
 
     virtual size32_t udp_write_to(const SocketEndpoint &ep, void const* buf, size32_t size)
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::udp_write_to: not implemented");
     }
 
     virtual bool check_connection()
     {
-        throw MakeStringException(-1, "not implemented");
+        throw MakeStringException(-1, "CSecureSocket::check_connection: not implemented");
     }
 
     virtual bool isSecure() const override

+ 8 - 0
tools/testsocket/CMakeLists.txt

@@ -26,6 +26,7 @@ project( testsocket )
 include_directories (            
          ./../../system/jlib
          ./../../system/include 
+         ./../../system/security/securesocket
     )
 
 ADD_DEFINITIONS( -D_CONSOLE )
@@ -35,6 +36,13 @@ HPCC_ADD_EXECUTABLE ( testsocket ${SRCS} )
 target_link_libraries ( testsocket
          jlib
          )
+
+IF (USE_OPENSSL)
+    target_link_libraries ( testsocket
+         securesocket
+         )
+ENDIF()
+
 install ( TARGETS testsocket RUNTIME DESTINATION ${EXEC_DIR} )
 
 if ( PLATFORM )

+ 43 - 9
tools/testsocket/testsocket.cpp

@@ -25,9 +25,11 @@
 #include "jdebug.hpp"
 #include "jthread.hpp"
 #include "jfile.hpp"
+#include "securesocket.hpp"
 
 bool abortEarly = false;
 bool forceHTTP = false;
+bool useSSL = false;
 bool abortAfterFirst = false;
 bool echoResults = false;
 bool saveResults = true;
@@ -51,8 +53,10 @@ unsigned runningQueries;
 unsigned multiThreadMax;
 unsigned maxLineSize = 10000000;
 
-ISocket *persistSocket;
-bool persistConnections;
+Owned<ISocket> persistSocket;
+bool persistConnections = false;
+Owned<ISecureSocketContext> persistSecureContext;
+Owned<ISecureSocket> persistSSock;
 
 int repeats = 0;
 StringBuffer queryPrefix;
@@ -370,7 +374,8 @@ int ReceiveThread::run()
 
 int doSendQuery(const char * ip, unsigned port, const char * base)
 {
-    ISocket * socket;
+    Owned<ISocket> socket;
+    Owned<ISecureSocketContext> secureContext;
     __int64 starttime, endtime;
     StringBuffer ipstr;
     try
@@ -404,15 +409,40 @@ int doSendQuery(const char * ip, unsigned port, const char * base)
         starttime= get_cycles_now();
         if (persistConnections)
         {
-            if (!persistSocket) {
+            if (!persistSocket)
+            {
                 SocketEndpoint ep(ip,port);
-                persistSocket = ISocket::connect_timeout(ep, 1000);
+                persistSocket.setown(ISocket::connect_timeout(ep, 1000));
+                if (useSSL)
+                {
+#ifdef _USE_OPENSSL
+                    if (!persistSecureContext)
+                        persistSecureContext.setown(createSecureSocketContext(ClientSocket));
+                    persistSSock.setown(persistSecureContext->createSecureSocket(persistSocket.getClear()));
+                    persistSSock->secure_connect();
+                    persistSocket.setown(persistSSock.getClear());
+#else
+                    throw MakeStringException(-1, "OpenSSL disabled in build");
+#endif
+                }
             }
             socket = persistSocket;
         }
-        else {
+        else
+        {
             SocketEndpoint ep(ip,port);
-            socket = ISocket::connect_timeout(ep,1000);
+            socket.setown(ISocket::connect_timeout(ep, 1000));
+            if (useSSL)
+            {
+#ifdef _USE_OPENSSL
+                secureContext.setown(createSecureSocketContext(ClientSocket));
+                Owned<ISecureSocket> ssock = secureContext->createSecureSocket(socket.getClear());
+                ssock->secure_connect();
+                socket.setown(ssock.getClear());
+#else
+                throw MakeStringException(1, "OpenSSL disabled in build");
+#endif
+            }
         }
     }
     catch(IException * e)
@@ -562,7 +592,6 @@ int doSendQuery(const char * ip, unsigned port, const char * base)
     if (!persistConnections)
     {
         socket->close();
-        socket->Release();
     }
     return 0;
 }
@@ -622,6 +651,7 @@ void usage(int exitCode)
     printf("  -rl       roxie logfile mode\n");
     printf("  -s        add stars to indicate transfer packets\n");
     printf("  -ss       suppress XML Status messages to screen (always suppressed from tracefile)\n");
+    printf("  -ssl      use ssl\n");
     printf("  -td       add debug timing statistics to trace\n");
     printf("  -tf       add full timing statistics to trace\n");
     printf("  -time     add timing to trace\n");
@@ -684,6 +714,11 @@ int main(int argc, char **argv)
             forceHTTP = true;
             ++arg;
         }
+        else if (stricmp(argv[arg], "-ssl") == 0)
+        {
+            useSSL = true;
+            ++arg;
+        }
         else if (stricmp(argv[arg], "-") == 0)
         {
             fromStdIn = true;
@@ -946,7 +981,6 @@ int main(int argc, char **argv)
         int sendlen=0;
         persistSocket->write(&sendlen, sizeof(sendlen));
         persistSocket->close();
-        persistSocket->Release();
     }
 
     endtime = get_cycles_now();