View Javadoc
1   /*
2    * Copyright 2019-2021 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package nl.altindag.ssl.util;
18  
19  import nl.altindag.ssl.SSLFactory;
20  import org.junit.jupiter.api.Test;
21  import org.junit.jupiter.api.extension.ExtendWith;
22  import org.mockito.junit.jupiter.MockitoExtension;
23  
24  import javax.net.ssl.SSLContext;
25  import javax.net.ssl.SSLSession;
26  import javax.net.ssl.SSLSessionContext;
27  import java.time.ZonedDateTime;
28  import java.util.Collections;
29  import java.util.List;
30  
31  import static org.assertj.core.api.Assertions.assertThat;
32  import static org.assertj.core.api.Assertions.assertThatThrownBy;
33  import static org.mockito.ArgumentMatchers.any;
34  import static org.mockito.Mockito.mock;
35  import static org.mockito.Mockito.times;
36  import static org.mockito.Mockito.verify;
37  import static org.mockito.Mockito.when;
38  
39  /**
40   * @author Hakan Altindag
41   */
42  @ExtendWith(MockitoExtension.class)
43  class SSLSessionUtilsShould {
44  
45      @Test
46      void invalidateCachesWithSslFactory() {
47          SSLFactory sslFactory = mock(SSLFactory.class);
48          SSLContext sslContext = mock(SSLContext.class);
49          SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
50          SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
51          SSLSession clientSession = mock(SSLSession.class);
52          SSLSession serverSession = mock(SSLSession.class);
53  
54          when(sslFactory.getSslContext()).thenReturn(sslContext);
55          when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
56          when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
57  
58          when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
59          when(serverSessionContext.getSession(any())).thenReturn(serverSession);
60  
61          when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
62          when(clientSessionContext.getSession(any())).thenReturn(clientSession);
63  
64          SSLSessionUtils.invalidateCaches(sslFactory);
65  
66          verify(serverSession, times(1)).invalidate();
67          verify(clientSession, times(1)).invalidate();
68      }
69  
70      @Test
71      void invalidateCachesWithSslContext() {
72          SSLFactory sslFactory = mock(SSLFactory.class);
73          SSLContext sslContext = mock(SSLContext.class);
74          SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
75          SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
76          SSLSession clientSession = mock(SSLSession.class);
77          SSLSession serverSession = mock(SSLSession.class);
78  
79          when(sslFactory.getSslContext()).thenReturn(sslContext);
80          when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
81          when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
82  
83          when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
84          when(serverSessionContext.getSession(any())).thenReturn(serverSession);
85  
86          when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
87          when(clientSessionContext.getSession(any())).thenReturn(clientSession);
88  
89          SSLSessionUtils.invalidateCaches(sslFactory.getSslContext());
90  
91          verify(serverSession, times(1)).invalidate();
92          verify(clientSession, times(1)).invalidate();
93      }
94  
95      @Test
96      void invalidateCachesBeforeGivenTimeStamp() {
97          SSLFactory sslFactory = mock(SSLFactory.class);
98          SSLContext sslContext = mock(SSLContext.class);
99          SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
100         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
101         SSLSession clientSession = mock(SSLSession.class);
102         SSLSession serverSession = mock(SSLSession.class);
103 
104         when(sslFactory.getSslContext()).thenReturn(sslContext);
105         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
106         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
107 
108         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
109         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
110         when(serverSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
111 
112         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
113         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
114         when(clientSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
115 
116         SSLSessionUtils.invalidateCachesBefore(sslFactory, ZonedDateTime.now());
117 
118         verify(serverSession, times(1)).invalidate();
119         verify(clientSession, times(1)).invalidate();
120     }
121 
122     @Test
123     void notInvalidateCachesWhenSessionTimeIsAheadOfGivenTimeStamp() {
124         SSLFactory sslFactory = mock(SSLFactory.class);
125         SSLContext sslContext = mock(SSLContext.class);
126         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
127         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
128         SSLSession clientSession = mock(SSLSession.class);
129         SSLSession serverSession = mock(SSLSession.class);
130 
131         when(sslFactory.getSslContext()).thenReturn(sslContext);
132         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
133         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
134 
135         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
136         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
137         when(serverSession.getCreationTime()).thenReturn(ZonedDateTime.now().plusHours(1).toInstant().toEpochMilli());
138 
139         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
140         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
141         when(clientSession.getCreationTime()).thenReturn(ZonedDateTime.now().plusHours(1).toInstant().toEpochMilli());
142 
143         SSLSessionUtils.invalidateCachesBefore(sslFactory, ZonedDateTime.now());
144 
145         verify(serverSession, times(0)).invalidate();
146         verify(clientSession, times(0)).invalidate();
147     }
148 
149     @Test
150     void invalidateCachesAfterGivenTimeStamp() {
151         SSLFactory sslFactory = mock(SSLFactory.class);
152         SSLContext sslContext = mock(SSLContext.class);
153         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
154         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
155         SSLSession clientSession = mock(SSLSession.class);
156         SSLSession serverSession = mock(SSLSession.class);
157 
158         when(sslFactory.getSslContext()).thenReturn(sslContext);
159         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
160         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
161 
162         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
163         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
164         when(serverSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
165 
166         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
167         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
168         when(clientSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
169 
170         SSLSessionUtils.invalidateCachesAfter(sslFactory, ZonedDateTime.now().minusHours(2));
171 
172         verify(serverSession, times(1)).invalidate();
173         verify(clientSession, times(1)).invalidate();
174     }
175 
176     @Test
177     void notInvalidateCachesWhenSessionTimeIsBeforeOfGivenTimeStamp() {
178         SSLFactory sslFactory = mock(SSLFactory.class);
179         SSLContext sslContext = mock(SSLContext.class);
180         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
181         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
182         SSLSession clientSession = mock(SSLSession.class);
183         SSLSession serverSession = mock(SSLSession.class);
184 
185         when(sslFactory.getSslContext()).thenReturn(sslContext);
186         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
187         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
188 
189         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
190         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
191         when(serverSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(3).toInstant().toEpochMilli());
192 
193         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
194         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
195         when(clientSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(3).toInstant().toEpochMilli());
196 
197         SSLSessionUtils.invalidateCachesAfter(sslFactory, ZonedDateTime.now().minusHours(2));
198 
199         verify(serverSession, times(0)).invalidate();
200         verify(clientSession, times(0)).invalidate();
201     }
202 
203     @Test
204     void invalidateCachesBetweenGivenTimeStamp() {
205         SSLFactory sslFactory = mock(SSLFactory.class);
206         SSLContext sslContext = mock(SSLContext.class);
207         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
208         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
209         SSLSession clientSession = mock(SSLSession.class);
210         SSLSession serverSession = mock(SSLSession.class);
211 
212         when(sslFactory.getSslContext()).thenReturn(sslContext);
213         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
214         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
215 
216         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
217         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
218         when(serverSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
219 
220         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
221         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
222         when(clientSession.getCreationTime()).thenReturn(ZonedDateTime.now().minusHours(1).toInstant().toEpochMilli());
223 
224         SSLSessionUtils.invalidateCachesBetween(sslFactory, ZonedDateTime.now().minusHours(2), ZonedDateTime.now());
225 
226         verify(serverSession, times(1)).invalidate();
227         verify(clientSession, times(1)).invalidate();
228     }
229 
230     @Test
231     void updateSessionTimeout() {
232         SSLFactory sslFactory = mock(SSLFactory.class);
233         SSLContext sslContext = mock(SSLContext.class);
234         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
235         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
236 
237         when(sslFactory.getSslContext()).thenReturn(sslContext);
238         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
239         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
240 
241         SSLSessionUtils.updateSessionTimeout(sslFactory, 10);
242 
243         verify(serverSessionContext, times(1)).setSessionTimeout(10);
244         verify(clientSessionContext, times(1)).setSessionTimeout(10);
245     }
246 
247     @Test
248     void updateSessionCacheSize() {
249         SSLFactory sslFactory = mock(SSLFactory.class);
250         SSLContext sslContext = mock(SSLContext.class);
251         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
252         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
253 
254         when(sslFactory.getSslContext()).thenReturn(sslContext);
255         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
256         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
257 
258         SSLSessionUtils.updateSessionCacheSize(sslFactory, 1024);
259 
260         verify(serverSessionContext, times(1)).setSessionCacheSize(1024);
261         verify(clientSessionContext, times(1)).setSessionCacheSize(1024);
262     }
263 
264     @Test
265     void getServerSslSessions() {
266         SSLFactory sslFactory = mock(SSLFactory.class);
267         SSLContext sslContext = mock(SSLContext.class);
268         SSLSessionContext serverSessionContext = mock(SSLSessionContext.class);
269         SSLSession serverSession = mock(SSLSession.class);
270 
271         when(sslFactory.getSslContext()).thenReturn(sslContext);
272         when(sslContext.getServerSessionContext()).thenReturn(serverSessionContext);
273 
274         when(serverSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
275         when(serverSessionContext.getSession(any())).thenReturn(serverSession);
276 
277         List<SSLSession> serverSslSessions = SSLSessionUtils.getServerSslSessions(sslFactory);
278 
279         assertThat(serverSslSessions).hasSize(1);
280     }
281 
282     @Test
283     void getClientSslSessions() {
284         SSLFactory sslFactory = mock(SSLFactory.class);
285         SSLContext sslContext = mock(SSLContext.class);
286         SSLSessionContext clientSessionContext = mock(SSLSessionContext.class);
287         SSLSession clientSession = mock(SSLSession.class);
288 
289         when(sslFactory.getSslContext()).thenReturn(sslContext);
290         when(sslContext.getClientSessionContext()).thenReturn(clientSessionContext);
291 
292         when(clientSessionContext.getIds()).thenReturn(Collections.enumeration(Collections.singletonList(new byte[]{1})));
293         when(clientSessionContext.getSession(any())).thenReturn(clientSession);
294 
295         List<SSLSession> clientSslSessions = SSLSessionUtils.getClientSslSessions(sslFactory);
296 
297         assertThat(clientSslSessions).hasSize(1);
298     }
299 
300     @Test
301     void throwExceptionWhenUpdateSessionTimeoutWithInvalidCacheSize() {
302         SSLFactory sslFactory = mock(SSLFactory.class);
303         SSLContext sslContext = mock(SSLContext.class);
304 
305         when(sslFactory.getSslContext()).thenReturn(sslContext);
306 
307         assertThatThrownBy(() -> SSLSessionUtils.updateSessionTimeout(sslFactory, -1))
308                 .isInstanceOf(IllegalArgumentException.class)
309                 .hasMessage("Unsupported timeout has been provided. Timeout should be equal or greater than [0], but received [-1]");
310     }
311 
312     @Test
313     void throwExceptionWhenUpdateSessionCacheSizeWithInvalidCacheSize() {
314         SSLFactory sslFactory = mock(SSLFactory.class);
315         SSLContext sslContext = mock(SSLContext.class);
316 
317         when(sslFactory.getSslContext()).thenReturn(sslContext);
318 
319         assertThatThrownBy(() -> SSLSessionUtils.updateSessionCacheSize(sslFactory, -1))
320                 .isInstanceOf(IllegalArgumentException.class)
321                 .hasMessage("Unsupported cache size has been provided. Cache size should be equal or greater than [0], but received [-1]");
322     }
323 
324 
325 
326 }