From b2211de8d8cbbbe560d30678758c9a529cb2de93 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 30 Sep 2024 11:37:37 -0400 Subject: [PATCH] Retire `ProvisioningAddress` and `WebsocketAddress` --- .../controllers/ProvisioningController.java | 7 +- .../push/ProvisioningManager.java | 26 +++---- .../textsecuregcm/storage/PubSubAddress.java | 11 --- .../websocket/ProvisioningAddress.java | 37 ---------- .../ProvisioningConnectListener.java | 20 ++++-- .../websocket/WebsocketAddress.java | 68 ------------------- .../ProvisioningControllerTest.java | 26 ++----- .../push/ProvisioningManagerTest.java | 18 +++-- 8 files changed, 42 insertions(+), 171 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java index a7eee615a..c072fe342 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java @@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ProvisioningManager; -import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; import org.whispersystems.websocket.auth.ReadOnly; /** @@ -96,8 +95,10 @@ public class ProvisioningController { rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid()); - if (!provisioningManager.sendProvisioningMessage(ProvisioningAddress.create(provisioningAddress), - Base64.getMimeDecoder().decode(message.body()))) { + final boolean subscriberPresent = + provisioningManager.sendProvisioningMessage(provisioningAddress, Base64.getMimeDecoder().decode(message.body())); + + if (!subscriberPresent) { throw new WebApplicationException(Response.Status.NOT_FOUND); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java index 57251f5f3..bddd25985 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ProvisioningManager.java @@ -28,8 +28,6 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil; -import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException; -import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; public class ProvisioningManager extends RedisPubSubAdapter implements Managed { @@ -39,7 +37,7 @@ public class ProvisioningManager extends RedisPubSubAdapter impl private final CircuitBreaker circuitBreaker; - private final Map> listenersByProvisioningAddress = + private final Map> listenersByProvisioningAddress = new ConcurrentHashMap<>(); private static final String ACTIVE_LISTENERS_GAUGE_NAME = name(ProvisioningManager.class, "activeListeners"); @@ -82,21 +80,21 @@ public class ProvisioningManager extends RedisPubSubAdapter impl redisClient.shutdown(); } - public void addListener(final ProvisioningAddress address, final Consumer listener) { + public void addListener(final String address, final Consumer listener) { listenersByProvisioningAddress.put(address, listener); circuitBreaker.executeRunnable( - () -> subscriptionConnection.sync().subscribe(address.serialize().getBytes(StandardCharsets.UTF_8))); + () -> subscriptionConnection.sync().subscribe(address.getBytes(StandardCharsets.UTF_8))); } - public void removeListener(final ProvisioningAddress address) { + public void removeListener(final String address) { RedisOperation.unchecked(() -> circuitBreaker.executeRunnable( - () -> subscriptionConnection.sync().unsubscribe(address.serialize().getBytes(StandardCharsets.UTF_8)))); + () -> subscriptionConnection.sync().unsubscribe(address.getBytes(StandardCharsets.UTF_8)))); listenersByProvisioningAddress.remove(address); } - public boolean sendProvisioningMessage(final ProvisioningAddress address, final byte[] body) { + public boolean sendProvisioningMessage(final String address, final byte[] body) { final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.newBuilder() .setType(PubSubProtos.PubSubMessage.Type.DELIVER) .setContent(ByteString.copyFrom(body)) @@ -104,7 +102,7 @@ public class ProvisioningManager extends RedisPubSubAdapter impl final boolean receiverPresent = circuitBreaker.executeSupplier( () -> publicationConnection.sync() - .publish(address.serialize().getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0); + .publish(address.getBytes(StandardCharsets.UTF_8), pubSubMessage.toByteArray()) > 0); Metrics.counter(SEND_PROVISIONING_MESSAGE_COUNTER_NAME, "online", String.valueOf(receiverPresent)).increment(); @@ -114,7 +112,7 @@ public class ProvisioningManager extends RedisPubSubAdapter impl @Override public void message(final byte[] channel, final byte[] message) { try { - final ProvisioningAddress address = new ProvisioningAddress(new String(channel, StandardCharsets.UTF_8)); + final String address = new String(channel, StandardCharsets.UTF_8); final PubSubProtos.PubSubMessage pubSubMessage = PubSubProtos.PubSubMessage.parseFrom(message); if (pubSubMessage.getType() == PubSubProtos.PubSubMessage.Type.DELIVER) { @@ -129,8 +127,6 @@ public class ProvisioningManager extends RedisPubSubAdapter impl Metrics.counter(RECEIVE_PROVISIONING_MESSAGE_COUNTER_NAME, "listenerPresent", String.valueOf(listenerPresent)).increment(); } - } catch (final InvalidWebsocketAddressException e) { - logger.warn("Failed to parse provisioning address", e); } catch (final InvalidProtocolBufferException e) { logger.warn("Failed to parse pub/sub message", e); } @@ -138,10 +134,6 @@ public class ProvisioningManager extends RedisPubSubAdapter impl @Override public void unsubscribed(final byte[] channel, final long count) { - try { - listenersByProvisioningAddress.remove(new ProvisioningAddress(new String(channel))); - } catch (final InvalidWebsocketAddressException e) { - logger.warn("Failed to parse provisioning address for `unsubscribe` event", e); - } + listenersByProvisioningAddress.remove(new String(channel, StandardCharsets.UTF_8)); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java deleted file mode 100644 index f63ea86ff..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -public interface PubSubAddress { - - String serialize(); -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java deleted file mode 100644 index f726251d3..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.websocket; - -import java.security.SecureRandom; -import java.util.Base64; - -public class ProvisioningAddress extends WebsocketAddress { - - public static byte DEVICE_ID = 0; - - public static ProvisioningAddress create(String address) { - return new ProvisioningAddress(address, DEVICE_ID); - } - - private ProvisioningAddress(String address, byte deviceId) { - super(address, deviceId); - } - - public ProvisioningAddress(String serialized) throws InvalidWebsocketAddressException { - super(serialized); - } - - public String getAddress() { - return getNumber(); - } - - public static ProvisioningAddress generate() { - byte[] random = new byte[16]; - new SecureRandom().nextBytes(random); - - return create(Base64.getUrlEncoder().withoutPadding().encodeToString(random)); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index bdf30aa44..100c74bd4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.websocket; +import com.google.common.annotations.VisibleForTesting; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; @@ -13,8 +14,11 @@ import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import java.security.SecureRandom; +import java.util.Base64; import java.util.List; import java.util.Optional; +import java.util.UUID; /** * A "provisioning WebSocket" provides a mechanism for sending a caller-defined provisioning message from the primary @@ -40,7 +44,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate(); + final String provisioningAddress = UUID.randomUUID().toString(); context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress)); provisioningManager.addListener(provisioningAddress, message -> { @@ -53,9 +57,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { }); context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()), - Optional.of(MessageProtos.ProvisioningUuid.newBuilder() - .setUuid(provisioningAddress.getAddress()) - .build() - .toByteArray())); + Optional.of(generateProvisioningUuid().toByteArray())); + } + + private static MessageProtos.ProvisioningUuid generateProvisioningUuid() { + final byte[] provisioningAddress = new byte[16]; + new SecureRandom().nextBytes(provisioningAddress); + + return MessageProtos.ProvisioningUuid.newBuilder() + .setUuid(Base64.getUrlEncoder().encodeToString(provisioningAddress)) + .build(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java deleted file mode 100644 index 054479865..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.websocket; - -import org.whispersystems.textsecuregcm.storage.PubSubAddress; - -public class WebsocketAddress implements PubSubAddress { - - private final String number; - private final byte deviceId; - - public WebsocketAddress(String number, byte deviceId) { - this.number = number; - this.deviceId = deviceId; - } - - public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException { - try { - String[] parts = serialized.split(":", 2); - - if (parts.length != 2) { - throw new InvalidWebsocketAddressException("Bad address: " + serialized); - } - - this.number = parts[0]; - this.deviceId = Byte.parseByte(parts[1]); - } catch (NumberFormatException e) { - throw new InvalidWebsocketAddressException(e); - } - } - - public String getNumber() { - return number; - } - - public byte getDeviceId() { - return deviceId; - } - - public String serialize() { - return number + ":" + deviceId; - } - - public String toString() { - return serialize(); - } - - @Override - public boolean equals(Object other) { - if (other == null) return false; - if (!(other instanceof WebsocketAddress)) return false; - - WebsocketAddress that = (WebsocketAddress)other; - - return - this.number.equals(that.number) && - this.deviceId == that.deviceId; - } - - @Override - public int hashCode() { - return number.hashCode() ^ (int)deviceId; - } - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java index 31ee64c43..e2b831e65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.controllers; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; @@ -21,15 +20,14 @@ import io.dropwizard.testing.junit5.ResourceExtension; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Base64; -import java.util.UUID; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.apache.commons.lang3.RandomStringUtils; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -38,7 +36,6 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; -import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; @ExtendWith(DropwizardExtensionsSupport.class) class ProvisioningControllerTest { @@ -67,13 +64,13 @@ class ProvisioningControllerTest { @Test void sendProvisioningMessage() { - final String destination = UUID.randomUUID().toString(); + final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true); try (final Response response = RESOURCE_EXTENSION.getJerseyTest() - .target("/v1/provisioning/" + destination) + .target("/v1/provisioning/" + provisioningAddress) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)), @@ -81,31 +78,20 @@ class ProvisioningControllerTest { assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus()); - final ArgumentCaptor provisioningAddressCaptor = - ArgumentCaptor.forClass(ProvisioningAddress.class); - - final ArgumentCaptor provisioningMessageCaptor = ArgumentCaptor.forClass(byte[].class); - - verify(provisioningManager).sendProvisioningMessage(provisioningAddressCaptor.capture(), - provisioningMessageCaptor.capture()); - - assertEquals(destination, provisioningAddressCaptor.getValue().getAddress()); - assertEquals(ProvisioningAddress.DEVICE_ID, provisioningAddressCaptor.getValue().getDeviceId()); - - assertArrayEquals(messageBody, provisioningMessageCaptor.getValue()); + verify(provisioningManager).sendProvisioningMessage(provisioningAddress, messageBody); } } @Test void sendProvisioningMessageRateLimited() throws RateLimitExceededException { - final String destination = UUID.randomUUID().toString(); + final String provisioningAddress = RandomStringUtils.randomAlphanumeric(16); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); doThrow(new RateLimitExceededException(Duration.ZERO)) .when(messagesRateLimiter).validate(AuthHelper.VALID_UUID); try (final Response response = RESOURCE_EXTENSION.getJerseyTest() - .target("/v1/provisioning/" + destination) + .target("/v1/provisioning/" + provisioningAddress) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java index 72e1e8cc0..2633b8d8c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java @@ -8,6 +8,7 @@ import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import com.google.protobuf.ByteString; +import java.util.UUID; import java.util.function.Consumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -18,7 +19,6 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.TestRandomUtil; -import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; class ProvisioningManagerTest { @@ -42,14 +42,13 @@ class ProvisioningManagerTest { @Test void sendProvisioningMessage() { - final ProvisioningAddress address = ProvisioningAddress.create("address"); - + final String provisioningAddress = UUID.randomUUID().toString(); final byte[] content = TestRandomUtil.nextBytes(16); @SuppressWarnings("unchecked") final Consumer subscribedConsumer = mock(Consumer.class); - provisioningManager.addListener(address, subscribedConsumer); - provisioningManager.sendProvisioningMessage(address, content); + provisioningManager.addListener(provisioningAddress, subscribedConsumer); + provisioningManager.sendProvisioningMessage(provisioningAddress, content); final ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class); @@ -62,15 +61,14 @@ class ProvisioningManagerTest { @Test void removeListener() { - final ProvisioningAddress address = ProvisioningAddress.create("address"); - + final String provisioningAddress = UUID.randomUUID().toString(); final byte[] content = TestRandomUtil.nextBytes(16); @SuppressWarnings("unchecked") final Consumer subscribedConsumer = mock(Consumer.class); - provisioningManager.addListener(address, subscribedConsumer); - provisioningManager.removeListener(address); - provisioningManager.sendProvisioningMessage(address, content); + provisioningManager.addListener(provisioningAddress, subscribedConsumer); + provisioningManager.removeListener(provisioningAddress); + provisioningManager.sendProvisioningMessage(provisioningAddress, content); // Make sure that we give the message enough time to show up (if it was going to) before declaring victory verify(subscribedConsumer, after(PUBSUB_TIMEOUT_MILLIS).never()).accept(any());