Use a consistent provisioning address

This commit is contained in:
Jon Chambers 2024-10-01 13:34:37 -04:00 committed by GitHub
parent b284e95394
commit 26503dffdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 8 deletions

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
@ -43,7 +44,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final String provisioningAddress = UUID.randomUUID().toString();
final String provisioningAddress = generateProvisioningAddress();
context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
provisioningManager.addListener(provisioningAddress, message -> {
@ -56,15 +57,16 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
});
context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()),
Optional.of(generateProvisioningAddress().toByteArray()));
Optional.of(MessageProtos.ProvisioningAddress.newBuilder()
.setAddress(provisioningAddress)
.build().toByteArray()));
}
private static MessageProtos.ProvisioningAddress generateProvisioningAddress() {
@VisibleForTesting
public static String generateProvisioningAddress() {
final byte[] provisioningAddress = new byte[16];
new SecureRandom().nextBytes(provisioningAddress);
return MessageProtos.ProvisioningAddress.newBuilder()
.setAddress(Base64.getUrlEncoder().encodeToString(provisioningAddress))
.build();
return Base64.getUrlEncoder().encodeToString(provisioningAddress);
}
}

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
@ExtendWith(DropwizardExtensionsSupport.class)
class ProvisioningControllerTest {
@ -64,7 +65,7 @@ class ProvisioningControllerTest {
@Test
void sendProvisioningMessage() {
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true);
@ -84,7 +85,7 @@ class ProvisioningControllerTest {
@Test
void sendProvisioningMessageRateLimited() throws RateLimitExceededException {
final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16);
final String provisioningAddress = ProvisioningConnectListener.generateProvisioningAddress();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
doThrow(new RateLimitExceededException(Duration.ZERO))

View File

@ -0,0 +1,63 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
class ProvisioningConnectListenerTest {
private ProvisioningManager provisioningManager;
private ProvisioningConnectListener provisioningConnectListener;
@BeforeEach
void setUp() {
provisioningManager = mock(ProvisioningManager.class);
provisioningConnectListener = new ProvisioningConnectListener(provisioningManager);
}
@Test
void onWebSocketConnect() {
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
final WebSocketSessionContext context = new WebSocketSessionContext(webSocketClient);
provisioningConnectListener.onWebSocketConnect(context);
context.notifyClosed(1000, "Test");
final ArgumentCaptor<String> addListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
final ArgumentCaptor<String> removeListenerProvisioningAddressCaptor = ArgumentCaptor.forClass(String.class);
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> sendAddressCaptor =
ArgumentCaptor.forClass(Optional.class);
verify(provisioningManager).addListener(addListenerProvisioningAddressCaptor.capture(), any());
verify(provisioningManager).removeListener(removeListenerProvisioningAddressCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/v1/address"), any(), sendAddressCaptor.capture());
final String sentProvisioningAddress = sendAddressCaptor.getValue()
.map(provisioningAddressBytes -> {
try {
return MessageProtos.ProvisioningAddress.parseFrom(provisioningAddressBytes);
} catch (final InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
})
.map(MessageProtos.ProvisioningAddress::getAddress)
.orElseThrow();
assertEquals(addListenerProvisioningAddressCaptor.getValue(), removeListenerProvisioningAddressCaptor.getValue());
assertEquals(addListenerProvisioningAddressCaptor.getValue(), sentProvisioningAddress);
}
}