1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package nl.altindag.ssl.socket;
18
19 import org.junit.jupiter.api.Test;
20 import org.junit.jupiter.api.extension.ExtendWith;
21 import org.mockito.junit.jupiter.MockitoExtension;
22
23 import javax.net.ssl.SSLParameters;
24 import javax.net.ssl.SSLSocket;
25 import javax.net.ssl.SSLSocketFactory;
26 import java.io.ByteArrayInputStream;
27 import java.io.IOException;
28 import java.io.InputStream;
29 import java.net.InetAddress;
30 import java.net.Socket;
31
32 import static org.assertj.core.api.Assertions.assertThat;
33 import static org.mockito.ArgumentMatchers.any;
34 import static org.mockito.ArgumentMatchers.anyBoolean;
35 import static org.mockito.ArgumentMatchers.anyInt;
36 import static org.mockito.ArgumentMatchers.anyString;
37 import static org.mockito.Mockito.doReturn;
38 import static org.mockito.Mockito.mock;
39 import static org.mockito.Mockito.spy;
40 import static org.mockito.Mockito.times;
41 import static org.mockito.Mockito.verify;
42 import static org.mockito.Mockito.verifyNoInteractions;
43
44
45
46
47 @ExtendWith(MockitoExtension.class)
48 class CompositeSSLSocketFactoryShould {
49
50 private final SSLParameters sslParameters = spy(
51 new SSLParameters(
52 new String[] {"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384"},
53 new String[] {"TLSv1.2"}
54 )
55 );
56
57 private final SSLSocketFactory sslSocketFactory = mock(SSLSocketFactory.class);
58
59 private final CompositeSSLSocketFactory victim = new CompositeSSLSocketFactory(sslSocketFactory, sslParameters);
60
61 @Test
62 void returnDefaultCipherSuites() {
63 String[] defaultCipherSuites = victim.getDefaultCipherSuites();
64
65 assertThat(defaultCipherSuites).containsExactly("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384");
66 verify(sslParameters, times(1)).getCipherSuites();
67 }
68
69 @Test
70 void returnSupportedCipherSuites() {
71 String[] supportedCipherSuites = victim.getSupportedCipherSuites();
72
73 assertThat(supportedCipherSuites).containsExactly("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384");
74 verify(sslParameters, times(1)).getCipherSuites();
75 }
76
77 @Test
78 void createSocket() throws IOException {
79 SSLSocket mockedSslSocket = mock(SSLSocket.class);
80
81 doReturn(mockedSslSocket).when(sslSocketFactory).createSocket();
82
83 Socket socket = victim.createSocket();
84
85 assertThat(socket).isNotNull();
86 verify(sslSocketFactory, times(1)).createSocket();
87 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
88 }
89
90 @Test
91 void createSocketDoesNotUseSslParametersWhenInnerSslSocketFactoryReturnsSocket() throws IOException {
92 Socket mockedSocket = mock(Socket.class);
93
94 doReturn(mockedSocket).when(sslSocketFactory).createSocket();
95
96 Socket socket = victim.createSocket();
97
98 assertThat(socket).isNotNull();
99 verify(sslSocketFactory, times(1)).createSocket();
100 verifyNoInteractions(mockedSocket);
101 }
102
103 @Test
104 void createSocketWithSocketInputStreamAutoClosable() throws IOException {
105 Socket baseSocket = mock(SSLSocket.class);
106 SSLSocket mockedSslSocket = mock(SSLSocket.class);
107 InputStream inputStream = new ByteArrayInputStream(new byte[]{});
108
109 doReturn(mockedSslSocket)
110 .when(sslSocketFactory).createSocket(any(Socket.class), any(InputStream.class), anyBoolean());
111
112 Socket socket = victim.createSocket(baseSocket, inputStream, true);
113
114 assertThat(socket).isNotNull();
115 verify(sslSocketFactory, times(1)).createSocket(baseSocket, inputStream, true);
116 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
117 }
118
119 @Test
120 void createSocketWithSocketHostPortAutoClosable() throws IOException {
121 Socket baseSocket = mock(SSLSocket.class);
122 SSLSocket mockedSslSocket = mock(SSLSocket.class);
123
124 doReturn(mockedSslSocket)
125 .when(sslSocketFactory).createSocket(any(Socket.class), anyString(), anyInt(), anyBoolean());
126
127 Socket socket = victim.createSocket(baseSocket, "localhost", 8443, true);
128
129 assertThat(socket).isNotNull();
130 verify(sslSocketFactory, times(1)).createSocket(baseSocket, "localhost", 8443, true);
131 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
132 }
133
134 @Test
135 void createSocketWithHostPort() throws IOException {
136 SSLSocket mockedSslSocket = mock(SSLSocket.class);
137
138 doReturn(mockedSslSocket)
139 .when(sslSocketFactory).createSocket(anyString(), anyInt());
140
141 Socket socket = victim.createSocket("localhost", 8443);
142
143 assertThat(socket).isNotNull();
144 verify(sslSocketFactory, times(1)).createSocket("localhost", 8443);
145 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
146 }
147
148 @Test
149 void createSocketWithHostPortLocalAddressLocalPort() throws IOException {
150 SSLSocket mockedSslSocket = mock(SSLSocket.class);
151
152 doReturn(mockedSslSocket)
153 .when(sslSocketFactory).createSocket(anyString(), anyInt(), any(InetAddress.class), anyInt());
154
155 Socket socket = victim.createSocket("localhost", 8443, InetAddress.getLocalHost(), 1234);
156
157 assertThat(socket).isNotNull();
158 verify(sslSocketFactory, times(1)).createSocket("localhost", 8443, InetAddress.getLocalHost(), 1234);
159 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
160 }
161
162 @Test
163 void createSocketWithAddressPort() throws IOException {
164 SSLSocket mockedSslSocket = mock(SSLSocket.class);
165
166 doReturn(mockedSslSocket)
167 .when(sslSocketFactory).createSocket(any(InetAddress.class), anyInt());
168
169 Socket socket = victim.createSocket(InetAddress.getLocalHost(), 1234);
170
171 assertThat(socket).isNotNull();
172 verify(sslSocketFactory, times(1)).createSocket(InetAddress.getLocalHost(), 1234);
173 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
174 }
175
176 @Test
177 void createSocketWithAddressPortLocalAddressPort() throws IOException {
178 SSLSocket mockedSslSocket = mock(SSLSocket.class);
179
180 doReturn(mockedSslSocket)
181 .when(sslSocketFactory).createSocket(any(InetAddress.class), anyInt(), any(InetAddress.class), anyInt());
182
183 Socket socket = victim.createSocket(InetAddress.getLocalHost(), 1234, InetAddress.getLocalHost(), 4321);
184
185 assertThat(socket).isNotNull();
186 verify(sslSocketFactory, times(1)).createSocket(InetAddress.getLocalHost(), 1234, InetAddress.getLocalHost(), 4321);
187 verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
188 }
189
190 }