diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index bc817dcb0..dd4e7993e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.websocket; +import com.google.common.annotations.VisibleForTesting; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; @@ -43,7 +44,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final String provisioningAddress = UUID.randomUUID().toString(); + final String provisioningAddress = generateProvisioningAddress(); context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress)); provisioningManager.addListener(provisioningAddress, message -> { @@ -56,15 +57,16 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { }); context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()), - Optional.of(generateProvisioningAddress().toByteArray())); + Optional.of(MessageProtos.ProvisioningAddress.newBuilder() + .setAddress(provisioningAddress) + .build().toByteArray())); } - private static MessageProtos.ProvisioningAddress generateProvisioningAddress() { + @VisibleForTesting + public static String generateProvisioningAddress() { final byte[] provisioningAddress = new byte[16]; new SecureRandom().nextBytes(provisioningAddress); - return MessageProtos.ProvisioningAddress.newBuilder() - .setAddress(Base64.getUrlEncoder().encodeToString(provisioningAddress)) - .build(); + return Base64.getUrlEncoder().encodeToString(provisioningAddress); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java index e2b831e65..a16ce21d8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; @ExtendWith(DropwizardExtensionsSupport.class) class ProvisioningControllerTest { @@ -64,7 +65,7 @@ class ProvisioningControllerTest { @Test void sendProvisioningMessage() { - final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16); + final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress(); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true); @@ -84,7 +85,7 @@ class ProvisioningControllerTest { @Test void sendProvisioningMessageRateLimited() throws RateLimitExceededException { - final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16); + final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress(); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); doThrow(new RateLimitExceededException(Duration.ZERO)) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java new file mode 100644 index 000000000..bff2bf74e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java @@ -0,0 +1,63 @@ +package org.whispersystems.textsecuregcm.websocket; + +import com.google.protobuf.InvalidProtocolBufferException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.websocket.WebSocketClient; +import org.whispersystems.websocket.session.WebSocketSessionContext; + +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +class ProvisioningConnectListenerTest { + + private ProvisioningManager provisioningManager; + private ProvisioningConnectListener provisioningConnectListener; + + @BeforeEach + void setUp() { + provisioningManager = mock(ProvisioningManager.class); + provisioningConnectListener = new ProvisioningConnectListener(provisioningManager); + } + + @Test + void onWebSocketConnect() { + final WebSocketClient webSocketClient = mock(WebSocketClient.class); + final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient); + + provisioningConnectListener.onWebSocketConnect(context); + context.notifyClosed(1000, "Test"); + + final ArgumentCaptor addListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class); + final ArgumentCaptor removeListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class); + + @SuppressWarnings("unchecked") final ArgumentCaptor> sendAddressCaptor = + ArgumentCaptor.forClass(Optional.class); + + verify(provisioningManager).addListener(addListenerProvisioningAddressCaptor.capture(), any()); + verify(provisioningManager).removeListener(removeListenerProvisioningAddressCaptor.capture()); + verify(webSocketClient).sendRequest(eq("PUT"), eq("/v1/address"), any(), sendAddressCaptor.capture()); + + final String sentProvisioningAddress = sendAddressCaptor.getValue() + .map(provisioningAddressBytes -> { + try { + return MessageProtos.ProvisioningAddress.parseFrom(provisioningAddressBytes); + } catch (final InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .map(MessageProtos.ProvisioningAddress::getAddress) + .orElseThrow(); + + assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue()); + assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress); + } +}