Use a consistent provisioning address
This commit is contained in:
parent
b284e95394
commit
26503dffdf
|
@ -5,6 +5,7 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
|
||||
|
@ -43,7 +44,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
|
|||
|
||||
@Override
|
||||
public void onWebSocketConnect(WebSocketSessionContext context) {
|
||||
final String provisioningAddress = UUID.randomUUID().toString();
|
||||
final String provisioningAddress = generateProvisioningAddress();
|
||||
context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
|
||||
|
||||
provisioningManager.addListener(provisioningAddress, message -> {
|
||||
|
@ -56,15 +57,16 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
|
|||
});
|
||||
|
||||
context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()),
|
||||
Optional.of(generateProvisioningAddress().toByteArray()));
|
||||
Optional.of(MessageProtos.ProvisioningAddress.newBuilder()
|
||||
.setAddress(provisioningAddress)
|
||||
.build().toByteArray()));
|
||||
}
|
||||
|
||||
private static MessageProtos.ProvisioningAddress generateProvisioningAddress() {
|
||||
@VisibleForTesting
|
||||
public static String generateProvisioningAddress() {
|
||||
final byte[] provisioningAddress = new byte[16];
|
||||
new SecureRandom().nextBytes(provisioningAddress);
|
||||
|
||||
return MessageProtos.ProvisioningAddress.newBuilder()
|
||||
.setAddress(Base64.getUrlEncoder().encodeToString(provisioningAddress))
|
||||
.build();
|
||||
return Base64.getUrlEncoder().encodeToString(provisioningAddress);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
|
|||
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class ProvisioningControllerTest {
|
||||
|
@ -64,7 +65,7 @@ class ProvisioningControllerTest {
|
|||
|
||||
@Test
|
||||
void sendProvisioningMessage() {
|
||||
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
|
||||
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
|
||||
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true);
|
||||
|
@ -84,7 +85,7 @@ class ProvisioningControllerTest {
|
|||
|
||||
@Test
|
||||
void sendProvisioningMessageRateLimited() throws RateLimitExceededException {
|
||||
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
|
||||
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
|
||||
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO))
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
package org.whispersystems.textsecuregcm.websocket;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||
import org.whispersystems.websocket.WebSocketClient;
|
||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
class ProvisioningConnectListenerTest {
|
||||
|
||||
private ProvisioningManager provisioningManager;
|
||||
private ProvisioningConnectListener provisioningConnectListener;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
provisioningManager = mock(ProvisioningManager.class);
|
||||
provisioningConnectListener = new ProvisioningConnectListener(provisioningManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
void onWebSocketConnect() {
|
||||
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
|
||||
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
|
||||
|
||||
provisioningConnectListener.onWebSocketConnect(context);
|
||||
context.notifyClosed(1000, "Test");
|
||||
|
||||
final ArgumentCaptor<String> addListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
|
||||
final ArgumentCaptor<String> removeListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> sendAddressCaptor =
|
||||
ArgumentCaptor.forClass(Optional.class);
|
||||
|
||||
verify(provisioningManager).addListener(addListenerProvisioningAddressCaptor.capture(), any());
|
||||
verify(provisioningManager).removeListener(removeListenerProvisioningAddressCaptor.capture());
|
||||
verify(webSocketClient).sendRequest(eq("PUT"), eq("/v1/address"), any(), sendAddressCaptor.capture());
|
||||
|
||||
final String sentProvisioningAddress = sendAddressCaptor.getValue()
|
||||
.map(provisioningAddressBytes -> {
|
||||
try {
|
||||
return MessageProtos.ProvisioningAddress.parseFrom(provisioningAddressBytes);
|
||||
} catch (final InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
})
|
||||
.map(MessageProtos.ProvisioningAddress::getAddress)
|
||||
.orElseThrow();
|
||||
|
||||
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
|
||||
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue