1 /** 2 * Copyright (c) 2010, 2014, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it under 6 * the terms of the GNU General Public License version 2 only, as published by 7 * the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT ANY 10 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 * A PARTICULAR PURPOSE. See the GNU General Public License version 2 for more 12 * details (a copy is included in the LICENSE file that accompanied this code). 13 * 14 * You should have received a copy of the GNU General Public License version 2 15 * along with this work; if not, write to the Free Software Foundation, Inc., 51 16 * Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 17 * 18 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA or 19 * visit www.oracle.com if you need additional information or have any 20 * questions. 21 */ 22 23 import java.io.IOException; 24 import java.io.InputStream; 25 import java.io.OutputStream; 26 import java.net.InetSocketAddress; 27 import java.net.SocketException; 28 import java.util.Arrays; 29 import java.util.List; 30 31 import javax.net.ssl.KeyManager; 32 import javax.net.ssl.SSLContext; 33 import javax.net.ssl.SSLHandshakeException; 34 import javax.net.ssl.SSLParameters; 35 import javax.net.ssl.SSLProtocolException; 36 import javax.net.ssl.SSLServerSocket; 37 import javax.net.ssl.SSLServerSocketFactory; 38 import javax.net.ssl.SSLSocket; 39 import javax.net.ssl.SSLSocketFactory; 40 import javax.net.ssl.TrustManager; 41 42 /** 43 * @test 44 * @library ../../../../lib/testlibrary/ 45 * @build jdk.testlibrary.Utils 46 * @compile CipherTestUtils.java 47 * @summary Test that TLS_FALLBACK_SCSV works in the client and server 48 * TLS code 49 */ 50 51 public final class FallbackSCSV { 52 public static void main(String[] args) throws Exception { 53 FallbackSCSV test = new FallbackSCSV(); 54 test.test(); 55 test.immediateConnect = true; 56 test.test(); 57 test.cipherTest.checkResult(null); 58 } 59 60 private static final String SSL_CONTEXT = "TLS"; 61 62 final CipherTestUtils cipherTest; 63 private final MyX509KeyManager keyManager; 64 65 /** 66 * If true, the socket is connected by the factory. 67 */ 68 private boolean immediateConnect; 69 70 private FallbackSCSV() throws Exception { 71 cipherTest = CipherTestUtils.getInstance(); 72 keyManager = new MyX509KeyManager(cipherTest.getClientKeyManager()); 73 } 74 75 private void test() throws Exception { 76 List<String> protocols = getProtocols(); 77 String maxProtocol = protocols.get(protocols.size() - 1); 78 79 // Without TLS_FALLBACK_SCSV. 80 run(protocols, protocols, false, maxProtocol); 81 run(protocols.subList(0, protocols.size() - 1), protocols, 82 false, protocols.get(protocols.size() - 2)); 83 run(protocols, protocols.subList(0, protocols.size() - 1), 84 false, protocols.get(protocols.size() - 2)); 85 86 // With TLS_FALLBACK_SCSV. 87 run(protocols, protocols, true, maxProtocol); 88 run(protocols.subList(0, protocols.size() - 1), protocols, 89 true, null); 90 run(protocols, protocols.subList(0, protocols.size() - 1), 91 true, protocols.get(protocols.size() - 2)); 92 } 93 94 private List<String> getProtocols() throws Exception { 95 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT); 96 ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()}, 97 new TrustManager[]{cipherTest.getServerTrustManager()}, 98 CipherTestUtils.secureRandom); 99 return Arrays.asList(ctx.getDefaultSSLParameters().getProtocols()); 100 } 101 102 private void run(List<String> clientProtocols, 103 List<String> serverProtocols, 104 boolean fallbackSCSV, String expectedProtocol) throws Exception { 105 Server server = new Server(serverProtocols); 106 Thread serverThread = new Thread(server); 107 serverThread.start(); 108 try { 109 runClient(server, clientProtocols, fallbackSCSV, expectedProtocol); 110 } catch (Exception e) { 111 CipherTestUtils.addFailure(e); 112 System.out.println("Exception:"); 113 e.printStackTrace(System.err); 114 } finally { 115 try { 116 serverThread.interrupt(); 117 server.close(); 118 serverThread.join(); 119 } catch (Exception e) { 120 CipherTestUtils.addFailure(e); 121 System.out.println("Exception:"); 122 e.printStackTrace(System.err); 123 } 124 } 125 } 126 127 private SSLSocket createClientSocket( 128 SSLSocketFactory factory, Server server) throws Exception { 129 if (immediateConnect) { 130 return (SSLSocket) factory.createSocket( 131 server.getAddress().getAddress(), 132 server.getAddress().getPort()); 133 } else { 134 return (SSLSocket) factory.createSocket(); 135 } 136 } 137 138 private void runClient(Server server, List<String> clientProtocols, 139 boolean fallbackSCSV, String expectedProtocol) throws Exception { 140 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT); 141 ctx.init(new KeyManager[]{keyManager}, 142 new TrustManager[]{cipherTest.getClientTrustManager()}, 143 CipherTestUtils.secureRandom); 144 SSLSocketFactory factory = (SSLSocketFactory) ctx.getSocketFactory(); 145 try (SSLSocket socket = createClientSocket(factory, server)) { 146 socket.setEnabledProtocols(clientProtocols.toArray( 147 new String[clientProtocols.size()])); 148 if (fallbackSCSV) { 149 SSLParameters params = socket.getSSLParameters(); 150 params.setSendFallbackSCSV(true); 151 socket.setSSLParameters(params); 152 if (!socket.getSSLParameters().getSendFallbackSCSV()) { 153 throw new Exception("SendFallbackSCSV not set"); 154 } 155 } else { 156 if (socket.getSSLParameters().getSendFallbackSCSV()) { 157 throw new Exception("SendFallbackSCSV set"); 158 } 159 } 160 socket.setSoTimeout(CipherTestUtils.TIMEOUT); 161 socket.setTcpNoDelay(true); 162 if (!immediateConnect) { 163 socket.connect(server.getAddress()); 164 } 165 if (expectedProtocol == null) { 166 try { 167 socket.startHandshake(); 168 } catch (SSLHandshakeException e) { 169 if (!"Received fatal alert: inappropriate_fallback".equals( 170 e.getMessage())) { 171 CipherTestUtils.addFailure(e); 172 System.out.println("Exception:"); 173 e.printStackTrace(System.err); 174 } 175 return; 176 } 177 throw new Exception("No exception for fallback to: " 178 + socket.getSession().getProtocol()); 179 } 180 byte[] buf = {1, 2, 3, 0}; 181 try (OutputStream out = socket.getOutputStream()) { 182 out.write(buf, 0, 3); 183 try (InputStream in = socket.getInputStream()) { 184 for (int i = 0; i < 4; ++i) { 185 int b = in.read(); 186 int expected = (-(i + 1)) & 0xFF; 187 if (i == 3) { 188 expected = -1; 189 } 190 if (b != expected) { 191 throw new Exception(String.format( 192 "byte mismatch at %d, got %d", i, b)); 193 } 194 } 195 196 if (!socket.getSession().getProtocol().equals( 197 expectedProtocol)) { 198 throw new Exception(String.format( 199 "Protocol is %s, but expected %s", 200 socket.getSession().getProtocol(), 201 expectedProtocol)); 202 } 203 } 204 } 205 } 206 } 207 208 class Server implements Runnable { 209 private final SSLServerSocket serverSocket; 210 211 Server(List<String> protocols) throws Exception { 212 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT); 213 ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()}, 214 new TrustManager[]{cipherTest.getServerTrustManager()}, 215 CipherTestUtils.secureRandom); 216 SSLServerSocketFactory factory = 217 (SSLServerSocketFactory)ctx.getServerSocketFactory(); 218 serverSocket = 219 (SSLServerSocket) factory.createServerSocket(0); 220 serverSocket.setEnabledProtocols( 221 protocols.toArray(new String[protocols.size()])); 222 } 223 224 InetSocketAddress getAddress() { 225 return (InetSocketAddress) serverSocket.getLocalSocketAddress(); 226 } 227 228 void close() throws Exception { 229 serverSocket.close(); 230 } 231 232 @Override 233 public void run() { 234 while (!Thread.currentThread().isInterrupted()) { 235 try (SSLSocket socket = (SSLSocket) serverSocket.accept()) { 236 socket.setSoTimeout(CipherTestUtils.TIMEOUT); 237 socket.setTcpNoDelay(true); 238 239 try { 240 socket.startHandshake(); 241 } catch (SSLHandshakeException e) { 242 if (!"Client protocol downgrade is not allowed".equals( 243 e.getMessage())) { 244 throw e; 245 } 246 continue; 247 } 248 249 try (InputStream in = socket.getInputStream(); 250 OutputStream out = socket.getOutputStream()) { 251 for (int i = 0; i < 3; ++i) { 252 int b = in.read(); 253 if (b < 0) { 254 break; 255 } 256 out.write((-b) & 0xFF); 257 } 258 out.flush(); 259 } catch (Exception e) { 260 CipherTestUtils.addFailure(e); 261 System.out.println("Exception:"); 262 e.printStackTrace(System.err); 263 return; 264 } 265 } catch (Exception e) { 266 if (e.getClass() == SocketException.class 267 && "Socket closed".equals(e.getMessage())) { 268 // Concurrent close by main thread. 269 return; 270 } 271 CipherTestUtils.addFailure(e); 272 System.out.println("Exception:"); 273 e.printStackTrace(System.err); 274 return; 275 } 276 } 277 try { 278 serverSocket.close(); 279 } catch (Exception e) { 280 CipherTestUtils.addFailure(e); 281 } 282 } 283 } 284 }