View Javadoc
1   /*
2    * Copyright 2019-2021 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package nl.altindag.ssl.socket;
18  
19  import org.junit.jupiter.api.Test;
20  import org.junit.jupiter.api.extension.ExtendWith;
21  import org.mockito.junit.jupiter.MockitoExtension;
22  
23  import javax.net.ssl.SSLParameters;
24  import javax.net.ssl.SSLSocket;
25  import javax.net.ssl.SSLSocketFactory;
26  import java.io.ByteArrayInputStream;
27  import java.io.IOException;
28  import java.io.InputStream;
29  import java.net.InetAddress;
30  import java.net.Socket;
31  
32  import static org.assertj.core.api.Assertions.assertThat;
33  import static org.mockito.ArgumentMatchers.any;
34  import static org.mockito.ArgumentMatchers.anyBoolean;
35  import static org.mockito.ArgumentMatchers.anyInt;
36  import static org.mockito.ArgumentMatchers.anyString;
37  import static org.mockito.Mockito.doReturn;
38  import static org.mockito.Mockito.mock;
39  import static org.mockito.Mockito.spy;
40  import static org.mockito.Mockito.times;
41  import static org.mockito.Mockito.verify;
42  import static org.mockito.Mockito.verifyNoInteractions;
43  
44  /**
45   * @author Hakan Altindag
46   */
47  @ExtendWith(MockitoExtension.class)
48  class CompositeSSLSocketFactoryShould {
49  
50      private final SSLParameters sslParameters = spy(
51              new SSLParameters(
52                      new String[] {"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384"},
53                      new String[] {"TLSv1.2"}
54              )
55      );
56  
57      private final SSLSocketFactory sslSocketFactory = mock(SSLSocketFactory.class);
58  
59      private final CompositeSSLSocketFactory victim = new CompositeSSLSocketFactory(sslSocketFactory, sslParameters);
60  
61      @Test
62      void returnDefaultCipherSuites() {
63          String[] defaultCipherSuites = victim.getDefaultCipherSuites();
64  
65          assertThat(defaultCipherSuites).containsExactly("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384");
66          verify(sslParameters, times(1)).getCipherSuites();
67      }
68  
69      @Test
70      void returnSupportedCipherSuites() {
71          String[] supportedCipherSuites = victim.getSupportedCipherSuites();
72  
73          assertThat(supportedCipherSuites).containsExactly("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384");
74          verify(sslParameters, times(1)).getCipherSuites();
75      }
76  
77      @Test
78      void createSocket() throws IOException {
79          SSLSocket mockedSslSocket = mock(SSLSocket.class);
80  
81          doReturn(mockedSslSocket).when(sslSocketFactory).createSocket();
82  
83          Socket socket = victim.createSocket();
84  
85          assertThat(socket).isNotNull();
86          verify(sslSocketFactory, times(1)).createSocket();
87          verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
88      }
89  
90      @Test
91      void createSocketDoesNotUseSslParametersWhenInnerSslSocketFactoryReturnsSocket() throws IOException {
92          Socket mockedSocket = mock(Socket.class);
93  
94          doReturn(mockedSocket).when(sslSocketFactory).createSocket();
95  
96          Socket socket = victim.createSocket();
97  
98          assertThat(socket).isNotNull();
99          verify(sslSocketFactory, times(1)).createSocket();
100         verifyNoInteractions(mockedSocket);
101     }
102 
103     @Test
104     void createSocketWithSocketInputStreamAutoClosable() throws IOException {
105         Socket baseSocket = mock(SSLSocket.class);
106         SSLSocket mockedSslSocket = mock(SSLSocket.class);
107         InputStream inputStream = new ByteArrayInputStream(new byte[]{});
108 
109         doReturn(mockedSslSocket)
110                 .when(sslSocketFactory).createSocket(any(Socket.class), any(InputStream.class), anyBoolean());
111 
112         Socket socket = victim.createSocket(baseSocket, inputStream, true);
113 
114         assertThat(socket).isNotNull();
115         verify(sslSocketFactory, times(1)).createSocket(baseSocket, inputStream, true);
116         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
117     }
118 
119     @Test
120     void createSocketWithSocketHostPortAutoClosable() throws IOException {
121         Socket baseSocket = mock(SSLSocket.class);
122         SSLSocket mockedSslSocket = mock(SSLSocket.class);
123 
124         doReturn(mockedSslSocket)
125                 .when(sslSocketFactory).createSocket(any(Socket.class), anyString(), anyInt(), anyBoolean());
126 
127         Socket socket = victim.createSocket(baseSocket, "localhost", 8443, true);
128 
129         assertThat(socket).isNotNull();
130         verify(sslSocketFactory, times(1)).createSocket(baseSocket, "localhost", 8443, true);
131         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
132     }
133 
134     @Test
135     void createSocketWithHostPort() throws IOException {
136         SSLSocket mockedSslSocket = mock(SSLSocket.class);
137 
138         doReturn(mockedSslSocket)
139                 .when(sslSocketFactory).createSocket(anyString(), anyInt());
140 
141         Socket socket = victim.createSocket("localhost", 8443);
142 
143         assertThat(socket).isNotNull();
144         verify(sslSocketFactory, times(1)).createSocket("localhost", 8443);
145         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
146     }
147 
148     @Test
149     void createSocketWithHostPortLocalAddressLocalPort() throws IOException {
150         SSLSocket mockedSslSocket = mock(SSLSocket.class);
151 
152         doReturn(mockedSslSocket)
153                 .when(sslSocketFactory).createSocket(anyString(), anyInt(), any(InetAddress.class), anyInt());
154 
155         Socket socket = victim.createSocket("localhost", 8443, InetAddress.getLocalHost(), 1234);
156 
157         assertThat(socket).isNotNull();
158         verify(sslSocketFactory, times(1)).createSocket("localhost", 8443, InetAddress.getLocalHost(), 1234);
159         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
160     }
161 
162     @Test
163     void createSocketWithAddressPort() throws IOException {
164         SSLSocket mockedSslSocket = mock(SSLSocket.class);
165 
166         doReturn(mockedSslSocket)
167                 .when(sslSocketFactory).createSocket(any(InetAddress.class), anyInt());
168 
169         Socket socket = victim.createSocket(InetAddress.getLocalHost(), 1234);
170 
171         assertThat(socket).isNotNull();
172         verify(sslSocketFactory, times(1)).createSocket(InetAddress.getLocalHost(), 1234);
173         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
174     }
175 
176     @Test
177     void createSocketWithAddressPortLocalAddressPort() throws IOException {
178         SSLSocket mockedSslSocket = mock(SSLSocket.class);
179 
180         doReturn(mockedSslSocket)
181                 .when(sslSocketFactory).createSocket(any(InetAddress.class), anyInt(), any(InetAddress.class), anyInt());
182 
183         Socket socket = victim.createSocket(InetAddress.getLocalHost(), 1234, InetAddress.getLocalHost(), 4321);
184 
185         assertThat(socket).isNotNull();
186         verify(sslSocketFactory, times(1)).createSocket(InetAddress.getLocalHost(), 1234, InetAddress.getLocalHost(), 4321);
187         verify(mockedSslSocket, times(1)).setSSLParameters(sslParameters);
188     }
189 
190 }