Browse Source

HPCC-17260 Roxie SSL listener for batch queries

Signed-off-by: Mark Kelly <mark.kelly@lexisnexisrisk.com>
Mark Kelly 8 years ago
parent
commit
f6539ea9c5
3 changed files with 44 additions and 108 deletions
  1. 14 26
      roxie/ccd/ccdmain.cpp
  2. 13 17
      roxie/ccd/ccdprotocol.cpp
  3. 17 65
      tools/testsocket/testsocket.cpp

+ 14 - 26
roxie/ccd/ccdmain.cpp

@@ -1043,8 +1043,8 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
         if (!localSlave)
             openMulticastSocket();
 
-        StringBuffer certFileBuf;
-        StringBuffer keyFileBuf;
+        StringBuffer certFileName;
+        StringBuffer keyFileName;
         setDaliServixSocketCaching(true);  // enable daliservix caching
         loadPlugins();
         createDelayedReleaser();
@@ -1140,38 +1140,26 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
 #ifdef _USE_OPENSSL
                         certFile = roxieFarm.queryProp("@certificateFileName");
                         if (!certFile)
-                        {
-                            WARNLOG("Skipping Roxie SSL Farm Listener on port %d due to missing certificateFileName tag", port);
-                            continue;
-                        }
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing certificateFileName tag", port);
                         if (isAbsolutePath(certFile))
-                            certFileBuf.append(certFile);
+                            certFileName.append(certFile);
                         else
-                            certFileBuf.append(codeDirectory.str()).append(certFile);
-                        if (!checkFileExists(certFileBuf.str()))
-                        {
-                            WARNLOG("Skipping Roxie SSL Farm Listener on port %d due to missing certificateFile (%s)", port, certFileBuf.str());
-                            continue;
-                        }
+                            certFileName.append(codeDirectory.str()).append(certFile);
+                        if (!checkFileExists(certFileName.str()))
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing certificateFile (%s)", port, certFileName.str());
 
                         keyFile =  roxieFarm.queryProp("@privateKeyFileName");
                         if (!keyFile)
-                        {
-                            WARNLOG("Skipping Roxie SSL Farm Listener on port %d due to missing privateKeyFileName tag", port);
-                            continue;
-                        }
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing privateKeyFileName tag", port);
                         if (isAbsolutePath(keyFile))
-                            keyFileBuf.append(keyFile);
+                            keyFileName.append(keyFile);
                         else
-                            keyFileBuf.append(codeDirectory.str()).append(keyFile);
-                        if (!checkFileExists(keyFileBuf.str()))
-                        {
-                            WARNLOG("Skipping Roxie SSL Farm Listener on port %d due to missing privateKeyFile (%s)", port, keyFileBuf.str());
-                            continue;
-                        }
+                            keyFileName.append(codeDirectory.str()).append(keyFile);
+                        if (!checkFileExists(keyFileName.str()))
+                            throw MakeStringException(ROXIE_FILE_ERROR, "Roxie SSL Farm Listener on port %d missing privateKeyFile (%s)", port, keyFileName.str());
 
                         passPhrase = roxieFarm.queryProp("@passphrase");
-                        if (passPhrase && (int)strlen(passPhrase) == 0)
+                        if (isEmptyString(passPhrase))
                             passPhrase = nullptr;
 #else
                         WARNLOG("Skipping Roxie SSL Farm Listener on port %d : OpenSSL disabled in build", port);
@@ -1181,7 +1169,7 @@ int STARTQUERY_API start_query(int argc, const char *argv[])
                     const char *soname =  roxieFarm.queryProp("@so");
                     const char *config  = roxieFarm.queryProp("@config");
                     Owned<IHpccProtocolPlugin> protocolPlugin = ensureProtocolPlugin(*protocolCtx, soname);
-                    roxieServer.setown(protocolPlugin->createListener(protocol ? protocol : "native", createRoxieProtocolMsgSink(ip, port, numThreads, suspended), port, listenQueue, config, certFileBuf.str(), keyFileBuf.str(), passPhrase));
+                    roxieServer.setown(protocolPlugin->createListener(protocol ? protocol : "native", createRoxieProtocolMsgSink(ip, port, numThreads, suspended), port, listenQueue, config, certFileName.str(), keyFileName.str(), passPhrase));
                 }
                 else
                     roxieServer.setown(createRoxieWorkUnitListener(numThreads, suspended));

+ 13 - 17
roxie/ccd/ccdprotocol.cpp

@@ -279,8 +279,8 @@ public:
         started.signal();
         while (running)
         {
-            ISocket *client = socket->accept(true);
-            ISecureSocket *ssock = nullptr;
+            Owned<ISocket> client = socket->accept(true);
+            Owned<ISecureSocket> ssock;
             if (client)
             {
                 if (protocol && streq(protocol, "ssl"))
@@ -290,11 +290,17 @@ public:
                     {
                         if (!secureContext)
                             secureContext.setown(createSecureSocketContextEx(certFile, keyFile, passPhrase, ServerSocket));
-                        ssock = secureContext->createSecureSocket(client);
+                        ssock.setown(secureContext->createSecureSocket(client.getClear()));
+                        int status = ssock->secure_accept();
+                        if (status < 0)
+                        {
+                            // secure_accept may also DBGLOG() errors ...
+                            WARNLOG("ProtocolSocketListener failure to establish secure connection");
+                            continue;
+                        }
                     }
                     catch (IException *E)
                     {
-                        client->Release();
                         StringBuffer s;
                         E->errorMessage(s);
                         WARNLOG("%s", s.str());
@@ -303,28 +309,18 @@ public:
                     }
                     catch (...)
                     {
-                        client->Release();
-                        WARNLOG("ProtocolSocketListener failure to establish secure connection");
-                        continue;
-                    }
-                    int status = ssock->secure_accept();
-                    if (status < 0)
-                    {
-                        ssock->Release();
-                        client->Release();
-                        // secure_accept may also DBGLOG() errors ...
+                        StringBuffer s;
                         WARNLOG("ProtocolSocketListener failure to establish secure connection");
                         continue;
                     }
-                    client = ssock;
+                    client.setown(ssock.getClear());
 #else
-                    client->Release();
                     WARNLOG("ProtocolSocketListener failure to establish secure connection: OpenSSL disabled in build");
                     continue;
 #endif
                 }
                 client->set_linger(-1);
-                pool->start(client);
+                pool->start(client.getClear());
             }
         }
         DBGLOG("ProtocolSocketListener closed query socket");

+ 17 - 65
tools/testsocket/testsocket.cpp

@@ -53,10 +53,10 @@ unsigned runningQueries;
 unsigned multiThreadMax;
 unsigned maxLineSize = 10000000;
 
-ISocket *persistSocket = nullptr;
+Owned<ISocket> persistSocket;
 bool persistConnections = false;
-ISecureSocketContext *persistSecureContext = nullptr;
-ISecureSocket *persistSSock = nullptr;
+Owned<ISecureSocketContext> persistSecureContext;
+Owned<ISecureSocket> persistSSock;
 
 int repeats = 0;
 StringBuffer queryPrefix;
@@ -374,9 +374,8 @@ int ReceiveThread::run()
 
 int doSendQuery(const char * ip, unsigned port, const char * base)
 {
-    ISocket * socket;
-    ISecureSocketContext *secureContext = nullptr;
-    ISecureSocket *ssock = nullptr;
+    Owned<ISocket> socket;
+    Owned<ISecureSocketContext> secureContext;
     __int64 starttime, endtime;
     StringBuffer ipstr;
     try
@@ -413,38 +412,17 @@ int doSendQuery(const char * ip, unsigned port, const char * base)
             if (!persistSocket)
             {
                 SocketEndpoint ep(ip,port);
-                persistSocket = ISocket::connect_timeout(ep, 1000);
+                persistSocket.setown(ISocket::connect_timeout(ep, 1000));
                 if (useSSL)
                 {
 #ifdef _USE_OPENSSL
-                    try
-                    {
-                        if (!persistSecureContext)
-                            persistSecureContext = createSecureSocketContext(ClientSocket);
-                        persistSSock = persistSecureContext->createSecureSocket(persistSocket);
-                    }
-                    catch (IException *e)
-                    {
-                        persistSocket->Release();
-                        throw e;
-                    }
-                    catch (...)
-                    {
-                        persistSocket->Release();
-                        throw MakeStringException(1, "SSL connect fail");
-                    }
-                    int status = persistSSock->secure_connect();
-                    if (status < 0)
-                    {
-                        // secure_connect may also DBGLOG() errors ...
-                        persistSSock->Release();
-                        persistSocket->Release();
-                        throw MakeStringException(1, "SSL connect fail");
-                    }
-                    persistSocket = persistSSock;
+                    if (!persistSecureContext)
+                        persistSecureContext.setown(createSecureSocketContext(ClientSocket));
+                    persistSSock.setown(persistSecureContext->createSecureSocket(persistSocket.getClear()));
+                    persistSSock->secure_connect();
+                    persistSocket.setown(persistSSock.getClear());
 #else
-                    persistSocket->Release();
-                    throw MakeStringException(1, "OpenSSL disabled in build");
+                    throw MakeStringException(-1, "OpenSSL disabled in build");
 #endif
                 }
             }
@@ -453,37 +431,15 @@ int doSendQuery(const char * ip, unsigned port, const char * base)
         else
         {
             SocketEndpoint ep(ip,port);
-            socket = ISocket::connect_timeout(ep, 1000);
+            socket.setown(ISocket::connect_timeout(ep, 1000));
             if (useSSL)
             {
 #ifdef _USE_OPENSSL
-                try
-                {
-                    secureContext = createSecureSocketContext(ClientSocket);
-                    ssock = secureContext->createSecureSocket(socket);
-                }
-                catch (IException *e)
-                {
-                    socket->Release();
-                    throw e;
-                }
-                catch (...)
-                {
-                    socket->Release();
-                    throw MakeStringException(1, "SSL connect fail");
-                }
-                int status = ssock->secure_connect();
-                if (status < 0)
-                {
-                    // secure_connect may also DBGLOG() errors ...
-                    ssock->Release();
-                    secureContext->Release();
-                    socket->Release();
-                    throw MakeStringException(1, "SSL connect fail");
-                }
-                socket = ssock;
+                secureContext.setown(createSecureSocketContext(ClientSocket));
+                Owned<ISecureSocket> ssock = secureContext->createSecureSocket(socket.getClear());
+                ssock->secure_connect();
+                socket.setown(ssock.getClear());
 #else
-                socket->Release();
                 throw MakeStringException(1, "OpenSSL disabled in build");
 #endif
             }
@@ -635,10 +591,7 @@ int doSendQuery(const char * ip, unsigned port, const char * base)
 
     if (!persistConnections)
     {
-        if (secureContext)
-            secureContext->Release();
         socket->close();
-        socket->Release();
     }
     return 0;
 }
@@ -1028,7 +981,6 @@ int main(int argc, char **argv)
         int sendlen=0;
         persistSocket->write(&sendlen, sizeof(sendlen));
         persistSocket->close();
-        persistSocket->Release();
     }
 
     endtime = get_cycles_now();