1 /**
   2  * Copyright (c) 2010, 2014, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it under
   6  * the terms of the GNU General Public License version 2 only, as published by
   7  * the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT ANY
  10  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
  11  * A PARTICULAR PURPOSE. See the GNU General Public License version 2 for more
  12  * details (a copy is included in the LICENSE file that accompanied this code).
  13  *
  14  * You should have received a copy of the GNU General Public License version 2
  15  * along with this work; if not, write to the Free Software Foundation, Inc., 51
  16  * Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  17  *
  18  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA or
  19  * visit www.oracle.com if you need additional information or have any
  20  * questions.
  21  */
  22 
  23 import java.io.IOException;
  24 import java.io.InputStream;
  25 import java.io.OutputStream;
  26 import java.net.InetSocketAddress;
  27 import java.net.SocketException;
  28 import java.util.Arrays;
  29 import java.util.List;
  30 
  31 import javax.net.ssl.KeyManager;
  32 import javax.net.ssl.SSLContext;
  33 import javax.net.ssl.SSLHandshakeException;
  34 import javax.net.ssl.SSLParameters;
  35 import javax.net.ssl.SSLProtocolException;
  36 import javax.net.ssl.SSLServerSocket;
  37 import javax.net.ssl.SSLServerSocketFactory;
  38 import javax.net.ssl.SSLSocket;
  39 import javax.net.ssl.SSLSocketFactory;
  40 import javax.net.ssl.TrustManager;
  41 
  42 /**
  43  * @test
  44  * @library ../../../../lib/testlibrary/
  45  * @build jdk.testlibrary.Utils
  46  * @compile CipherTestUtils.java
  47  * @summary Test that TLS_FALLBACK_SCSV works in the client and server
  48  * TLS code
  49  */
  50 
  51 public final class FallbackSCSV {
  52     public static void main(String[] args) throws Exception {
  53         FallbackSCSV test = new FallbackSCSV();
  54         test.test();
  55         test.immediateConnect = true;
  56         test.test();
  57         test.cipherTest.checkResult(null);
  58     }
  59     
  60     private static final String SSL_CONTEXT = "TLS";
  61 
  62     final CipherTestUtils cipherTest;
  63     private final MyX509KeyManager keyManager;
  64 
  65     /**
  66      * If true, the socket is connected by the factory.
  67      */
  68     private boolean immediateConnect;
  69     
  70     private FallbackSCSV() throws Exception {
  71         cipherTest = CipherTestUtils.getInstance();
  72         keyManager = new MyX509KeyManager(cipherTest.getClientKeyManager());
  73     }
  74     
  75     private void test() throws Exception {
  76         List<String> protocols = getProtocols();
  77         String maxProtocol = protocols.get(protocols.size() - 1);
  78 
  79         // Without TLS_FALLBACK_SCSV.
  80         run(protocols, protocols, false, maxProtocol);
  81         run(protocols.subList(0, protocols.size() - 1), protocols,
  82                 false, protocols.get(protocols.size() - 2));
  83         run(protocols, protocols.subList(0, protocols.size() - 1),
  84                 false, protocols.get(protocols.size() - 2));
  85 
  86         // With TLS_FALLBACK_SCSV.
  87         run(protocols, protocols, true, maxProtocol);
  88         run(protocols.subList(0, protocols.size() - 1), protocols,
  89                 true, null);
  90         run(protocols, protocols.subList(0, protocols.size() - 1),
  91                 true, protocols.get(protocols.size() - 2));
  92     }
  93     
  94     private List<String> getProtocols() throws Exception {
  95         SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
  96         ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()},
  97                 new TrustManager[]{cipherTest.getServerTrustManager()},
  98                 CipherTestUtils.secureRandom);
  99         return Arrays.asList(ctx.getDefaultSSLParameters().getProtocols());
 100     }
 101 
 102     private void run(List<String> clientProtocols,
 103             List<String> serverProtocols,
 104             boolean fallbackSCSV, String expectedProtocol) throws Exception {
 105         Server server = new Server(serverProtocols);
 106         Thread serverThread = new Thread(server);
 107         serverThread.start();
 108         try {
 109             runClient(server, clientProtocols, fallbackSCSV, expectedProtocol);
 110         } catch (Exception e) {
 111             CipherTestUtils.addFailure(e);
 112             System.out.println("Exception:");
 113             e.printStackTrace(System.err);
 114         } finally {
 115             try {
 116                 serverThread.interrupt();
 117                 server.close(); 
 118                 serverThread.join();
 119             } catch (Exception e) {
 120                 CipherTestUtils.addFailure(e);
 121                 System.out.println("Exception:");
 122                 e.printStackTrace(System.err);
 123             }
 124         }
 125     }
 126     
 127     private SSLSocket createClientSocket(
 128             SSLSocketFactory factory, Server server) throws Exception {
 129         if (immediateConnect) {
 130             return (SSLSocket) factory.createSocket(
 131                     server.getAddress().getAddress(),
 132                     server.getAddress().getPort());
 133         } else {
 134             return (SSLSocket) factory.createSocket();
 135         }
 136     }
 137 
 138     private void runClient(Server server, List<String> clientProtocols,
 139             boolean fallbackSCSV, String expectedProtocol) throws Exception {
 140         SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
 141         ctx.init(new KeyManager[]{keyManager},
 142                 new TrustManager[]{cipherTest.getClientTrustManager()},
 143                 CipherTestUtils.secureRandom);
 144         SSLSocketFactory factory = (SSLSocketFactory) ctx.getSocketFactory();
 145         try (SSLSocket socket = createClientSocket(factory, server)) {
 146             socket.setEnabledProtocols(clientProtocols.toArray(
 147                     new String[clientProtocols.size()]));
 148             if (fallbackSCSV) {
 149                 SSLParameters params = socket.getSSLParameters();
 150                 params.setSendFallbackSCSV(true);
 151                 socket.setSSLParameters(params);
 152                 if (!socket.getSSLParameters().getSendFallbackSCSV()) {
 153                     throw new Exception("SendFallbackSCSV not set");
 154                 }
 155             } else {
 156                 if (socket.getSSLParameters().getSendFallbackSCSV()) {
 157                     throw new Exception("SendFallbackSCSV set");
 158                 }
 159             }
 160             socket.setSoTimeout(CipherTestUtils.TIMEOUT);
 161             socket.setTcpNoDelay(true);
 162             if (!immediateConnect) {
 163                 socket.connect(server.getAddress());
 164             }
 165             if (expectedProtocol == null) {
 166                 try {
 167                     socket.startHandshake();
 168                 } catch (SSLHandshakeException e) {
 169                     if (!"Received fatal alert: inappropriate_fallback".equals(
 170                             e.getMessage())) {
 171                         CipherTestUtils.addFailure(e);
 172                         System.out.println("Exception:");
 173                         e.printStackTrace(System.err);
 174                     }
 175                     return;
 176                 }
 177                 throw new Exception("No exception for fallback to: "
 178                         + socket.getSession().getProtocol());
 179             }
 180             byte[] buf = {1, 2, 3, 0};
 181             try (OutputStream out = socket.getOutputStream()) {
 182                 out.write(buf, 0, 3);
 183                 try (InputStream in = socket.getInputStream()) {
 184                     for (int i = 0; i < 4; ++i) {
 185                         int b = in.read();
 186                         int expected = (-(i + 1)) & 0xFF;
 187                         if (i == 3) {
 188                             expected = -1;
 189                         }
 190                         if (b != expected) {
 191                             throw new Exception(String.format(
 192                                     "byte mismatch at %d, got %d", i, b));
 193                         }
 194                     }
 195                     
 196                     if (!socket.getSession().getProtocol().equals(
 197                             expectedProtocol)) {
 198                         throw new Exception(String.format(
 199                                 "Protocol is %s, but expected %s",
 200                                 socket.getSession().getProtocol(),
 201                                 expectedProtocol));
 202                     }
 203                 }
 204             }
 205         }
 206     }
 207 
 208     class Server implements Runnable {
 209         private final SSLServerSocket serverSocket;
 210 
 211         Server(List<String> protocols) throws Exception {
 212             SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
 213             ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()},
 214                     new TrustManager[]{cipherTest.getServerTrustManager()},
 215                     CipherTestUtils.secureRandom);
 216             SSLServerSocketFactory factory =
 217                     (SSLServerSocketFactory)ctx.getServerSocketFactory();
 218             serverSocket =
 219                     (SSLServerSocket) factory.createServerSocket(0);
 220             serverSocket.setEnabledProtocols(
 221                     protocols.toArray(new String[protocols.size()]));
 222         }
 223         
 224         InetSocketAddress getAddress() {
 225             return (InetSocketAddress) serverSocket.getLocalSocketAddress();
 226         }
 227         
 228         void close() throws Exception {
 229             serverSocket.close();
 230         }
 231         
 232         @Override
 233         public void run() {
 234             while (!Thread.currentThread().isInterrupted()) {
 235                 try (SSLSocket socket = (SSLSocket) serverSocket.accept()) {
 236                     socket.setSoTimeout(CipherTestUtils.TIMEOUT);
 237                     socket.setTcpNoDelay(true);
 238                     
 239                     try {
 240                         socket.startHandshake();
 241                     } catch (SSLHandshakeException e) {
 242                         if (!"Client protocol downgrade is not allowed".equals(
 243                                 e.getMessage())) {
 244                             throw e;
 245                         }
 246                         continue;
 247                     }
 248 
 249                     try (InputStream in = socket.getInputStream();
 250                             OutputStream out = socket.getOutputStream()) {
 251                         for (int i = 0; i < 3; ++i) {
 252                             int b = in.read();
 253                             if (b < 0) {
 254                                 break;
 255                             }
 256                             out.write((-b) & 0xFF);
 257                         }
 258                         out.flush();
 259                     } catch (Exception e) {
 260                         CipherTestUtils.addFailure(e);
 261                         System.out.println("Exception:");
 262                         e.printStackTrace(System.err);
 263                         return;
 264                     }
 265                 } catch (Exception e) {
 266                     if (e.getClass() == SocketException.class
 267                             && "Socket closed".equals(e.getMessage())) {
 268                         // Concurrent close by main thread.
 269                         return;
 270                     }
 271                     CipherTestUtils.addFailure(e);
 272                     System.out.println("Exception:");
 273                     e.printStackTrace(System.err);
 274                     return;
 275                 }
 276             }
 277             try {
 278                 serverSocket.close();
 279             } catch (Exception e) {
 280                 CipherTestUtils.addFailure(e);
 281             }
 282         }
 283     }
 284 }