1 /**
2 * Copyright (c) 2010, 2014, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it under
6 * the terms of the GNU General Public License version 2 only, as published by
7 * the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT ANY
10 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
11 * A PARTICULAR PURPOSE. See the GNU General Public License version 2 for more
12 * details (a copy is included in the LICENSE file that accompanied this code).
13 *
14 * You should have received a copy of the GNU General Public License version 2
15 * along with this work; if not, write to the Free Software Foundation, Inc., 51
16 * Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
17 *
18 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA or
19 * visit www.oracle.com if you need additional information or have any
20 * questions.
21 */
22
23 import java.io.IOException;
24 import java.io.InputStream;
25 import java.io.OutputStream;
26 import java.net.InetSocketAddress;
27 import java.net.SocketException;
28 import java.util.Arrays;
29 import java.util.List;
30
31 import javax.net.ssl.KeyManager;
32 import javax.net.ssl.SSLContext;
33 import javax.net.ssl.SSLHandshakeException;
34 import javax.net.ssl.SSLParameters;
35 import javax.net.ssl.SSLProtocolException;
36 import javax.net.ssl.SSLServerSocket;
37 import javax.net.ssl.SSLServerSocketFactory;
38 import javax.net.ssl.SSLSocket;
39 import javax.net.ssl.SSLSocketFactory;
40 import javax.net.ssl.TrustManager;
41
42 /**
43 * @test
44 * @library ../../../../lib/testlibrary/
45 * @build jdk.testlibrary.Utils
46 * @compile CipherTestUtils.java
47 * @summary Test that TLS_FALLBACK_SCSV works in the client and server
48 * TLS code
49 */
50
51 public final class FallbackSCSV {
52 public static void main(String[] args) throws Exception {
53 FallbackSCSV test = new FallbackSCSV();
54 test.test();
55 test.immediateConnect = true;
56 test.test();
57 test.cipherTest.checkResult(null);
58 }
59
60 private static final String SSL_CONTEXT = "TLS";
61
62 final CipherTestUtils cipherTest;
63 private final MyX509KeyManager keyManager;
64
65 /**
66 * If true, the socket is connected by the factory.
67 */
68 private boolean immediateConnect;
69
70 private FallbackSCSV() throws Exception {
71 cipherTest = CipherTestUtils.getInstance();
72 keyManager = new MyX509KeyManager(cipherTest.getClientKeyManager());
73 }
74
75 private void test() throws Exception {
76 List<String> protocols = getProtocols();
77 String maxProtocol = protocols.get(protocols.size() - 1);
78
79 // Without TLS_FALLBACK_SCSV.
80 run(protocols, protocols, false, maxProtocol);
81 run(protocols.subList(0, protocols.size() - 1), protocols,
82 false, protocols.get(protocols.size() - 2));
83 run(protocols, protocols.subList(0, protocols.size() - 1),
84 false, protocols.get(protocols.size() - 2));
85
86 // With TLS_FALLBACK_SCSV.
87 run(protocols, protocols, true, maxProtocol);
88 run(protocols.subList(0, protocols.size() - 1), protocols,
89 true, null);
90 run(protocols, protocols.subList(0, protocols.size() - 1),
91 true, protocols.get(protocols.size() - 2));
92 }
93
94 private List<String> getProtocols() throws Exception {
95 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
96 ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()},
97 new TrustManager[]{cipherTest.getServerTrustManager()},
98 CipherTestUtils.secureRandom);
99 return Arrays.asList(ctx.getDefaultSSLParameters().getProtocols());
100 }
101
102 private void run(List<String> clientProtocols,
103 List<String> serverProtocols,
104 boolean fallbackSCSV, String expectedProtocol) throws Exception {
105 Server server = new Server(serverProtocols);
106 Thread serverThread = new Thread(server);
107 serverThread.start();
108 try {
109 runClient(server, clientProtocols, fallbackSCSV, expectedProtocol);
110 } catch (Exception e) {
111 CipherTestUtils.addFailure(e);
112 System.out.println("Exception:");
113 e.printStackTrace(System.err);
114 } finally {
115 try {
116 serverThread.interrupt();
117 server.close();
118 serverThread.join();
119 } catch (Exception e) {
120 CipherTestUtils.addFailure(e);
121 System.out.println("Exception:");
122 e.printStackTrace(System.err);
123 }
124 }
125 }
126
127 private SSLSocket createClientSocket(
128 SSLSocketFactory factory, Server server) throws Exception {
129 if (immediateConnect) {
130 return (SSLSocket) factory.createSocket(
131 server.getAddress().getAddress(),
132 server.getAddress().getPort());
133 } else {
134 return (SSLSocket) factory.createSocket();
135 }
136 }
137
138 private void runClient(Server server, List<String> clientProtocols,
139 boolean fallbackSCSV, String expectedProtocol) throws Exception {
140 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
141 ctx.init(new KeyManager[]{keyManager},
142 new TrustManager[]{cipherTest.getClientTrustManager()},
143 CipherTestUtils.secureRandom);
144 SSLSocketFactory factory = (SSLSocketFactory) ctx.getSocketFactory();
145 try (SSLSocket socket = createClientSocket(factory, server)) {
146 socket.setEnabledProtocols(clientProtocols.toArray(
147 new String[clientProtocols.size()]));
148 if (fallbackSCSV) {
149 SSLParameters params = socket.getSSLParameters();
150 params.setSendFallbackSCSV(true);
151 socket.setSSLParameters(params);
152 if (!socket.getSSLParameters().getSendFallbackSCSV()) {
153 throw new Exception("SendFallbackSCSV not set");
154 }
155 } else {
156 if (socket.getSSLParameters().getSendFallbackSCSV()) {
157 throw new Exception("SendFallbackSCSV set");
158 }
159 }
160 socket.setSoTimeout(CipherTestUtils.TIMEOUT);
161 socket.setTcpNoDelay(true);
162 if (!immediateConnect) {
163 socket.connect(server.getAddress());
164 }
165 if (expectedProtocol == null) {
166 try {
167 socket.startHandshake();
168 } catch (SSLHandshakeException e) {
169 if (!"Received fatal alert: inappropriate_fallback".equals(
170 e.getMessage())) {
171 CipherTestUtils.addFailure(e);
172 System.out.println("Exception:");
173 e.printStackTrace(System.err);
174 }
175 return;
176 }
177 throw new Exception("No exception for fallback to: "
178 + socket.getSession().getProtocol());
179 }
180 byte[] buf = {1, 2, 3, 0};
181 try (OutputStream out = socket.getOutputStream()) {
182 out.write(buf, 0, 3);
183 try (InputStream in = socket.getInputStream()) {
184 for (int i = 0; i < 4; ++i) {
185 int b = in.read();
186 int expected = (-(i + 1)) & 0xFF;
187 if (i == 3) {
188 expected = -1;
189 }
190 if (b != expected) {
191 throw new Exception(String.format(
192 "byte mismatch at %d, got %d", i, b));
193 }
194 }
195
196 if (!socket.getSession().getProtocol().equals(
197 expectedProtocol)) {
198 throw new Exception(String.format(
199 "Protocol is %s, but expected %s",
200 socket.getSession().getProtocol(),
201 expectedProtocol));
202 }
203 }
204 }
205 }
206 }
207
208 class Server implements Runnable {
209 private final SSLServerSocket serverSocket;
210
211 Server(List<String> protocols) throws Exception {
212 SSLContext ctx = SSLContext.getInstance(SSL_CONTEXT);
213 ctx.init(new KeyManager[]{cipherTest.getServerKeyManager()},
214 new TrustManager[]{cipherTest.getServerTrustManager()},
215 CipherTestUtils.secureRandom);
216 SSLServerSocketFactory factory =
217 (SSLServerSocketFactory)ctx.getServerSocketFactory();
218 serverSocket =
219 (SSLServerSocket) factory.createServerSocket(0);
220 serverSocket.setEnabledProtocols(
221 protocols.toArray(new String[protocols.size()]));
222 }
223
224 InetSocketAddress getAddress() {
225 return (InetSocketAddress) serverSocket.getLocalSocketAddress();
226 }
227
228 void close() throws Exception {
229 serverSocket.close();
230 }
231
232 @Override
233 public void run() {
234 while (!Thread.currentThread().isInterrupted()) {
235 try (SSLSocket socket = (SSLSocket) serverSocket.accept()) {
236 socket.setSoTimeout(CipherTestUtils.TIMEOUT);
237 socket.setTcpNoDelay(true);
238
239 try {
240 socket.startHandshake();
241 } catch (SSLHandshakeException e) {
242 if (!"Client protocol downgrade is not allowed".equals(
243 e.getMessage())) {
244 throw e;
245 }
246 continue;
247 }
248
249 try (InputStream in = socket.getInputStream();
250 OutputStream out = socket.getOutputStream()) {
251 for (int i = 0; i < 3; ++i) {
252 int b = in.read();
253 if (b < 0) {
254 break;
255 }
256 out.write((-b) & 0xFF);
257 }
258 out.flush();
259 } catch (Exception e) {
260 CipherTestUtils.addFailure(e);
261 System.out.println("Exception:");
262 e.printStackTrace(System.err);
263 return;
264 }
265 } catch (Exception e) {
266 if (e.getClass() == SocketException.class
267 && "Socket closed".equals(e.getMessage())) {
268 // Concurrent close by main thread.
269 return;
270 }
271 CipherTestUtils.addFailure(e);
272 System.out.println("Exception:");
273 e.printStackTrace(System.err);
274 return;
275 }
276 }
277 try {
278 serverSocket.close();
279 } catch (Exception e) {
280 CipherTestUtils.addFailure(e);
281 }
282 }
283 }
284 }