From 977243ebfddd53e6f45ecd4578a906ac141bf8c0 Mon Sep 17 00:00:00 2001
From: Sergey Skrobotov
Date: Fri, 8 Sep 2023 16:08:59 -0700
Subject: [PATCH] DRY gRPC tests, refactor error mapping
---
.../textsecuregcm/WhisperServerService.java | 14 +-
.../RateLimitExceededException.java | 35 +++-
.../grpc/CallingGrpcService.java | 5 -
.../grpc/ConvertibleToGrpcStatus.java | 20 ++
.../grpc/ErrorMappingInterceptor.java | 49 +++++
.../textsecuregcm/grpc/KeysGrpcService.java | 5 -
.../grpc/ProfileGrpcService.java | 5 -
.../textsecuregcm/grpc/RateLimitUtil.java | 45 -----
.../grpc/CallingGrpcServiceTest.java | 65 ++-----
.../grpc/DevicesGrpcServiceTest.java | 99 ++++------
.../textsecuregcm/grpc/GrpcTestUtils.java | 65 +++++++
.../grpc/KeysAnonymousGrpcServiceTest.java | 77 +++-----
.../grpc/KeysGrpcServiceTest.java | 139 +++++---------
.../grpc/ProfileAnonymousGrpcServiceTest.java | 105 +++++-----
.../grpc/ProfileGrpcServiceTest.java | 180 ++++++------------
.../grpc/SimpleBaseGrpcTest.java | 146 ++++++++++++++
.../limits/RateLimitersTest.java | 32 ++--
.../textsecuregcm/util/MockUtils.java | 76 ++++++--
18 files changed, 637 insertions(+), 525 deletions(-)
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java
delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 080d03819..1b7185a8f 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -120,6 +120,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
+import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
@@ -644,8 +645,6 @@ public class WhisperServerService extends Application grpcServer = ServerBuilder.forPort(config.getGrpcPort())
- // TODO: specialize metrics with user-agent platform
- .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
@@ -657,13 +656,16 @@ public class WhisperServerService extends Application RETRY_AFTER_DURATION_KEY =
+ Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
+ @Override
+ public String toAsciiString(final Duration value) {
+ return value.toString();
+ }
+
+ @Override
+ public Duration parseAsciiString(final String serialized) {
+ return Duration.parse(serialized);
+ }
+ });
@Nullable
private final Duration retryDuration;
@@ -33,4 +49,19 @@ public class RateLimitExceededException extends Exception {
public boolean isLegacy() {
return legacy;
}
+
+ @Override
+ public Status grpcStatus() {
+ return Status.RESOURCE_EXHAUSTED;
+ }
+
+ @Override
+ public Optional grpcMetadata() {
+ return getRetryDuration()
+ .map(duration -> {
+ final Metadata metadata = new Metadata();
+ metadata.put(RETRY_AFTER_DURATION_KEY, duration);
+ return metadata;
+ });
+ }
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java
index f981f094d..7acab7547 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java
@@ -24,11 +24,6 @@ public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase {
this.rateLimiters = rateLimiters;
}
- @Override
- protected Throwable onErrorMap(final Throwable throwable) {
- return RateLimitUtil.mapRateLimitExceededException(throwable);
- }
-
@Override
public Mono getTurnCredentials(final GetTurnCredentialsRequest request) {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java
new file mode 100644
index 000000000..ce0cc00fb
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.Metadata;
+import io.grpc.Status;
+import java.util.Optional;
+
+/**
+ * Interface to be imlemented by our custom exceptions that are consistently mapped to a gRPC status.
+ */
+public interface ConvertibleToGrpcStatus {
+
+ Status grpcStatus();
+
+ Optional grpcMetadata();
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java
new file mode 100644
index 000000000..75d25d498
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.ForwardingServerCall;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+
+/**
+ * This interceptor observes responses from the service and if the response status is {@link Status#UNKNOWN}
+ * and there is a non-null cause which is an instance of {@link ConvertibleToGrpcStatus},
+ * then status and metadata to be returned to the client is resolved from that object.
+ *
+ * This eliminates the need of having each service to override {@code `onErrorMap()`} method for commonly used exceptions.
+ */
+public class ErrorMappingInterceptor implements ServerInterceptor {
+
+ @Override
+ public ServerCall.Listener interceptCall(
+ final ServerCall call,
+ final Metadata headers,
+ final ServerCallHandler next) {
+ return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) {
+ @Override
+ public void close(final Status status, final Metadata trailers) {
+ // The idea is to only apply the automatic conversion logic in the cases
+ // when there was no explicit decision by the service to provide a status.
+ // I.e. if at this point we see anything but the `UNKNOWN`,
+ // that means that some logic in the service made this decision already
+ // and automatic conversion may conflict with it.
+ if (status.getCode().equals(Status.Code.UNKNOWN)
+ && status.getCause() instanceof ConvertibleToGrpcStatus convertibleToGrpcStatus) {
+ super.close(
+ convertibleToGrpcStatus.grpcStatus(),
+ convertibleToGrpcStatus.grpcMetadata().orElseGet(Metadata::new)
+ );
+ } else {
+ super.close(status, trailers);
+ }
+ }
+ }, headers);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
index 9ced01fbf..6d3c3ccf9 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
@@ -74,11 +74,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
this.rateLimiters = rateLimiters;
}
- @Override
- protected Throwable onErrorMap(final Throwable throwable) {
- return RateLimitUtil.mapRateLimitExceededException(throwable);
- }
-
@Override
public Mono getPreKeyCount(final GetPreKeyCountRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java
index e9dd35e46..7c26a4582 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java
@@ -100,11 +100,6 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
this.bucket = bucket;
}
- @Override
- protected Throwable onErrorMap(final Throwable throwable) {
- return RateLimitUtil.mapRateLimitExceededException(throwable);
- }
-
@Override
public Mono setProfile(final SetProfileRequest request) {
validateRequest(request);
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
deleted file mode 100644
index dadfd5417..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright 2023 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.grpc;
-
-import io.grpc.Metadata;
-import io.grpc.Status;
-import io.grpc.StatusException;
-import java.time.Duration;
-import javax.annotation.Nullable;
-import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
-
-public class RateLimitUtil {
-
- public static final Metadata.Key RETRY_AFTER_DURATION_KEY =
- Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
- @Override
- public String toAsciiString(final Duration value) {
- return value.toString();
- }
-
- @Override
- public Duration parseAsciiString(final String serialized) {
- return Duration.parse(serialized);
- }
- });
-
- public static Throwable mapRateLimitExceededException(final Throwable throwable) {
- if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
- @Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
- .map(duration -> {
- final Metadata metadata = new Metadata();
- metadata.put(RETRY_AFTER_DURATION_KEY, duration);
-
- return metadata;
- }).orElse(null);
-
- return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
- }
-
- return throwable;
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java
index 0753dfadf..3f75c7eb9 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java
@@ -6,66 +6,41 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
-import io.grpc.ServerInterceptors;
-import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.List;
-import java.util.UUID;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
+import org.mockito.Mock;
import org.signal.chat.calling.CallingGrpc;
import org.signal.chat.calling.GetTurnCredentialsRequest;
import org.signal.chat.calling.GetTurnCredentialsResponse;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
-import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
-import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
-import org.whispersystems.textsecuregcm.storage.Device;
-import reactor.core.publisher.Mono;
+import org.whispersystems.textsecuregcm.util.MockUtils;
-class CallingGrpcServiceTest {
+class CallingGrpcServiceTest extends SimpleBaseGrpcTest {
+ @Mock
private TurnTokenGenerator turnTokenGenerator;
+
+ @Mock
private RateLimiter turnCredentialRateLimiter;
- private CallingGrpc.CallingBlockingStub callingStub;
-
- private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
- private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
-
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
-
- @BeforeEach
- void setUp() {
- turnTokenGenerator = mock(TurnTokenGenerator.class);
- turnCredentialRateLimiter = mock(RateLimiter.class);
+ @Override
+ protected CallingGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter);
-
- final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters);
-
- final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
-
- GRPC_SERVER_EXTENSION.getServiceRegistry()
- .addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor));
-
- callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
+ return new CallingGrpcService(turnTokenGenerator, rateLimiters);
}
@Test
@@ -74,10 +49,10 @@ class CallingGrpcServiceTest {
final String password = "test-password";
final List urls = List.of("first", "second");
- when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty());
+ MockUtils.updateRateLimiterResponseToAllow(turnCredentialRateLimiter, AUTHENTICATED_ACI);
when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls));
- final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
+ final GetTurnCredentialsResponse response = authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder()
.setUsername(username)
@@ -90,20 +65,10 @@ class CallingGrpcServiceTest {
@Test
void getTurnCredentialsRateLimited() {
- final Duration retryAfter = Duration.ofMinutes(19);
-
- when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI))
- .thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false)));
-
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
-
+ final Duration retryAfter = MockUtils.updateRateLimiterResponseToFail(
+ turnCredentialRateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(19), false);
+ assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
verify(turnTokenGenerator, never()).generate(any());
-
- assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
- assertNotNull(exception.getTrailers());
- assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
-
verifyNoInteractions(turnTokenGenerator);
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java
index 4effc71f5..b5a5e673c 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java
@@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
@@ -14,11 +13,10 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
-import io.grpc.ServerInterceptors;
import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
@@ -26,21 +24,19 @@ import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
-import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
+import org.mockito.Mock;
import org.signal.chat.device.ClearPushTokenRequest;
import org.signal.chat.device.ClearPushTokenResponse;
import org.signal.chat.device.DevicesGrpc;
@@ -54,42 +50,31 @@ import org.signal.chat.device.SetDeviceNameRequest;
import org.signal.chat.device.SetDeviceNameResponse;
import org.signal.chat.device.SetPushTokenRequest;
import org.signal.chat.device.SetPushTokenResponse;
-import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
-class DevicesGrpcServiceTest {
+class DevicesGrpcServiceTest extends SimpleBaseGrpcTest {
+ @Mock
private AccountsManager accountsManager;
+
+ @Mock
private KeysManager keysManager;
+
+ @Mock
private MessagesManager messagesManager;
+ @Mock
private Account authenticatedAccount;
- private MockAuthenticationInterceptor mockAuthenticationInterceptor;
- private DevicesGrpc.DevicesBlockingStub devicesStub;
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
-
- private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
- private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
-
- @BeforeEach
- void setUp() {
- accountsManager = mock(AccountsManager.class);
- keysManager = mock(KeysManager.class);
- messagesManager = mock(MessagesManager.class);
-
- authenticatedAccount = mock(Account.class);
+ @Override
+ protected DevicesGrpcService createServiceBeforeEachTest() {
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
- mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
-
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI))
.thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
@@ -117,11 +102,7 @@ class DevicesGrpcServiceTest {
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
- final DevicesGrpcService devicesGrpcService = new DevicesGrpcService(accountsManager, keysManager, messagesManager);
- devicesStub = DevicesGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
-
- GRPC_SERVER_EXTENSION.getServiceRegistry()
- .addService(ServerInterceptors.intercept(devicesGrpcService, mockAuthenticationInterceptor));
+ return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
}
@Test
@@ -161,14 +142,14 @@ class DevicesGrpcServiceTest {
.build())
.build();
- assertEquals(expectedResponse, devicesStub.getDevices(GetDevicesRequest.newBuilder().build()));
+ assertEquals(expectedResponse, authenticatedServiceStub().getDevices(GetDevicesRequest.newBuilder().build()));
}
@Test
void removeDevice() {
final long deviceId = 17;
- final RemoveDeviceResponse ignored = devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
+ final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId)
.build());
@@ -179,30 +160,23 @@ class DevicesGrpcServiceTest {
@Test
void removeDevicePrimary() {
- final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
- .setId(1)
- .build()));
-
- assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
+ .setId(1)
+ .build()));
}
@Test
void removeDeviceNonPrimaryAuthenticated() {
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
-
- final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
- .setId(17)
- .build()));
-
- assertEquals(Status.Code.PERMISSION_DENIED, exception.getStatus().getCode());
+ mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
+ assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
+ .setId(17)
+ .build()));
}
@ParameterizedTest
@ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1})
void setDeviceName(final long deviceId) {
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
+ mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
@@ -210,7 +184,7 @@ class DevicesGrpcServiceTest {
final byte[] deviceName = new byte[128];
ThreadLocalRandom.current().nextBytes(deviceName);
- final SetDeviceNameResponse ignored = devicesStub.setDeviceName(SetDeviceNameRequest.newBuilder()
+ final SetDeviceNameResponse ignored = authenticatedServiceStub().setDeviceName(SetDeviceNameRequest.newBuilder()
.setName(ByteString.copyFrom(deviceName))
.build());
@@ -221,11 +195,7 @@ class DevicesGrpcServiceTest {
@MethodSource
void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class)));
-
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> devicesStub.setDeviceName(request));
-
- assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setDeviceName(request));
}
private static Stream setDeviceNameIllegalArgument() {
@@ -248,12 +218,12 @@ class DevicesGrpcServiceTest {
@Nullable final String expectedApnsVoipToken,
@Nullable final String expectedFcmToken) {
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
+ mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
- final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
+ final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(device).setApnId(expectedApnsToken);
verify(device).setVoipApnId(expectedApnsVoipToken);
@@ -312,7 +282,7 @@ class DevicesGrpcServiceTest {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
- final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
+ final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
}
@@ -352,12 +322,7 @@ class DevicesGrpcServiceTest {
void setPushTokenIllegalArgument(final SetPushTokenRequest request) {
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
-
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> devicesStub.setPushToken(request));
-
- assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
-
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
}
@@ -383,7 +348,7 @@ class DevicesGrpcServiceTest {
@Nullable final String fcmToken,
@Nullable final String expectedUserAgent) {
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
+ mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
@@ -393,7 +358,7 @@ class DevicesGrpcServiceTest {
when(device.getGcmId()).thenReturn(fcmToken);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
- final ClearPushTokenResponse ignored = devicesStub.clearPushToken(ClearPushTokenRequest.newBuilder().build());
+ final ClearPushTokenResponse ignored = authenticatedServiceStub().clearPushToken(ClearPushTokenRequest.newBuilder().build());
verify(device).setApnId(null);
verify(device).setVoipApnId(null);
@@ -430,12 +395,12 @@ class DevicesGrpcServiceTest {
@CartesianTest.Values(booleans = {true, false}) final boolean pni,
@CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) {
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
+ mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
- final SetCapabilitiesResponse ignored = devicesStub.setCapabilities(SetCapabilitiesRequest.newBuilder()
+ final SetCapabilitiesResponse ignored = authenticatedServiceStub().setCapabilities(SetCapabilitiesRequest.newBuilder()
.setStorage(storage)
.setTransfer(transfer)
.setPni(pni)
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java
new file mode 100644
index 000000000..116aafc13
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+import io.grpc.BindableService;
+import io.grpc.ServerInterceptors;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import java.time.Duration;
+import java.util.UUID;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.function.Executable;
+import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
+import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+
+public final class GrpcTestUtils {
+
+ private GrpcTestUtils() {
+ // noop
+ }
+
+ public static MockAuthenticationInterceptor setupAuthenticatedExtension(
+ final GrpcServerExtension extension,
+ final UUID authenticatedAci,
+ final long authenticatedDeviceId,
+ final BindableService service) {
+ final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
+ mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
+ extension.getServiceRegistry()
+ .addService(ServerInterceptors.intercept(service, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
+ return mockAuthenticationInterceptor;
+ }
+
+ public static void setupUnauthenticatedExtension(
+ final GrpcServerExtension extension,
+ final BindableService service) {
+ extension.getServiceRegistry()
+ .addService(ServerInterceptors.intercept(service, new ErrorMappingInterceptor()));
+ }
+
+ public static void assertStatusException(final Status expected, final Executable serviceCall) {
+ final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
+ assertEquals(expected.getCode(), exception.getStatus().getCode());
+ }
+
+ public static void assertRateLimitExceeded(
+ final Duration expectedRetryAfter,
+ final Executable serviceCall,
+ final Object... mocksToCheckForNoInteraction) {
+ final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
+ assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus());
+ assertNotNull(exception.getTrailers());
+ assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY));
+ for (final Object mock: mocksToCheckForNoInteraction) {
+ verifyNoInteractions(mock);
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
index 35064f3be..721081c92 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
@@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.Status;
@@ -20,11 +21,10 @@ import java.util.Collections;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.Mock;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
@@ -48,27 +48,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
-class KeysAnonymousGrpcServiceTest {
+class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest {
+ @Mock
private AccountsManager accountsManager;
+
+ @Mock
private KeysManager keysManager;
- private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
-
- @BeforeEach
- void setUp() {
- accountsManager = mock(AccountsManager.class);
- keysManager = mock(KeysManager.class);
-
- final KeysAnonymousGrpcService keysGrpcService =
- new KeysAnonymousGrpcService(accountsManager, keysManager);
-
- keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
-
- GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
+ @Override
+ protected KeysAnonymousGrpcService createServiceBeforeEachTest() {
+ return new KeysAnonymousGrpcService(accountsManager, keysManager);
}
@Test
@@ -101,7 +92,7 @@ class KeysAnonymousGrpcServiceTest {
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
- final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
@@ -152,19 +143,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
- final StatusRuntimeException statusRuntimeException =
- assertThrows(StatusRuntimeException.class,
- () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
- .setRequest(GetPreKeysRequest.newBuilder()
- .setTargetIdentifier(ServiceIdentifier.newBuilder()
- .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
- .setUuid(UUIDUtil.toByteString(identifier))
- .build())
- .setDeviceId(Device.MASTER_ID)
- .build())
- .build()));
-
- assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setRequest(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(identifier))
+ .build())
+ .setDeviceId(Device.MASTER_ID)
+ .build())
+ .build()));
}
@Test
@@ -174,7 +161,7 @@ class KeysAnonymousGrpcServiceTest {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class,
- () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
@@ -205,19 +192,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class,
- () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
- .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
- .setRequest(GetPreKeysRequest.newBuilder()
- .setTargetIdentifier(ServiceIdentifier.newBuilder()
- .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
- .setUuid(UUIDUtil.toByteString(accountIdentifier))
- .build())
- .setDeviceId(deviceId)
- .build())
- .build()));
-
- assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
+ .setRequest(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(accountIdentifier))
+ .build())
+ .setDeviceId(deviceId)
+ .build())
+ .build()));
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
index 1ae1a443f..b0659107f 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
@@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
@@ -16,9 +15,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
-import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
@@ -32,14 +32,13 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.Mock;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
@@ -56,7 +55,6 @@ import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@@ -73,41 +71,34 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono;
-class KeysGrpcServiceTest {
-
- private AccountsManager accountsManager;
- private KeysManager keysManager;
- private RateLimiter preKeysRateLimiter;
-
- private Device authenticatedDevice;
-
- private KeysGrpc.KeysBlockingStub keysStub;
-
- private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
- private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
- private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
+class KeysGrpcServiceTest extends SimpleBaseGrpcTest {
private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
+
private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
+ protected static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
- @BeforeEach
- void setUp() {
- accountsManager = mock(AccountsManager.class);
- keysManager = mock(KeysManager.class);
- preKeysRateLimiter = mock(RateLimiter.class);
+ @Mock
+ private AccountsManager accountsManager;
+ @Mock
+ private KeysManager keysManager;
+
+ @Mock
+ private RateLimiter preKeysRateLimiter;
+
+ @Mock
+ private Device authenticatedDevice;
+
+
+ @Override
+ protected KeysGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
- final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
- keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
-
- authenticatedDevice = mock(Device.class);
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
final Account authenticatedAccount = mock(Account.class);
@@ -119,17 +110,13 @@ class KeysGrpcServiceTest {
when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
- final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
-
- GRPC_SERVER_EXTENSION.getServiceRegistry()
- .addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
-
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
+
+ return new KeysGrpcService(accountsManager, keysManager, rateLimiters);
}
@Test
@@ -152,7 +139,7 @@ class KeysGrpcServiceTest {
.setPniEcPreKeyCount(3)
.setPniKemPreKeyCount(4)
.build(),
- keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
+ authenticatedServiceStub().getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
}
@ParameterizedTest
@@ -168,7 +155,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
- keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
+ authenticatedServiceStub().setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
.map(preKey -> EcPreKey.newBuilder()
@@ -189,10 +176,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeEcPreKeys(request));
}
private static Stream setOneTimeEcPreKeysWithError() {
@@ -242,7 +226,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
- keysStub.setOneTimeKemSignedPreKeys(
+ authenticatedServiceStub().setOneTimeKemSignedPreKeys(
SetOneTimeKemSignedPreKeysRequest.newBuilder()
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
@@ -265,10 +249,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeKemSignedPreKeys(request));
}
private static Stream setOneTimeKemSignedPreKeysWithError() {
@@ -333,7 +314,7 @@ class KeysGrpcServiceTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored
- keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
+ authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(signedPreKey.keyId())
@@ -359,7 +340,7 @@ class KeysGrpcServiceTest {
@MethodSource
void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request));
+ assertThrows(StatusRuntimeException.class, () -> authenticatedServiceStub().setEcSignedPreKey(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
}
@@ -416,7 +397,7 @@ class KeysGrpcServiceTest {
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored
- keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
+ authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(KemSignedPreKey.newBuilder()
.setKeyId(lastResortPreKey.keyId())
@@ -437,10 +418,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setKemLastResortPreKey(request));
}
private static Stream setLastResortPreKeyWithError() {
@@ -528,7 +506,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
{
- final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
@@ -563,7 +541,7 @@ class KeysGrpcServiceTest {
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
{
- final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
@@ -606,15 +584,12 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
- .setTargetIdentifier(ServiceIdentifier.newBuilder()
- .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
- .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
- .build())
- .build()));
-
- assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
+ .build())
+ .build()));
}
@ParameterizedTest
@@ -631,16 +606,13 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
- .setTargetIdentifier(ServiceIdentifier.newBuilder()
- .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
- .setUuid(UUIDUtil.toByteString(accountIdentifier))
- .build())
- .setDeviceId(deviceId)
- .build()));
-
- assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(accountIdentifier))
+ .build())
+ .setDeviceId(deviceId)
+ .build()));
}
@Test
@@ -655,22 +627,15 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Duration retryAfterDuration = Duration.ofMinutes(7);
-
when(preKeysRateLimiter.validateReactive(anyString()))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
- .setTargetIdentifier(ServiceIdentifier.newBuilder()
- .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
- .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
- .build())
- .build()));
-
- assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
- assertNotNull(exception.getTrailers());
- assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
-
+ assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
+ .build())
+ .build()));
verifyNoInteractions(accountsManager);
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java
index c129ac8df..65ef88fe9 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java
@@ -1,19 +1,27 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
package org.whispersystems.textsecuregcm.grpc;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
+import io.grpc.Channel;
import io.grpc.Metadata;
import io.grpc.Status;
+import io.grpc.stub.MetadataUtils;
+import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Instant;
@@ -24,14 +32,12 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
-import io.grpc.StatusRuntimeException;
-import io.grpc.stub.MetadataUtils;
-import org.junit.jupiter.api.BeforeEach;
+import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.Mock;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType;
@@ -72,43 +78,41 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
-import javax.annotation.Nullable;
-public class ProfileAnonymousGrpcServiceTest {
+public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest {
+
+ @Mock
private Account account;
+
+ @Mock
private AccountsManager accountsManager;
+
+ @Mock
private ProfilesManager profilesManager;
+
+ @Mock
private ProfileBadgeConverter profileBadgeConverter;
- private ProfileAnonymousGrpc.ProfileAnonymousBlockingStub profileAnonymousBlockingStub;
+
+ @Mock
private ServerZkProfileOperations serverZkProfileOperations;
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
-
- @BeforeEach
- void setup() {
- account = mock(Account.class);
- accountsManager = mock(AccountsManager.class);
- profilesManager = mock(ProfilesManager.class);
- profileBadgeConverter = mock(ProfileBadgeConverter.class);
- serverZkProfileOperations = mock(ServerZkProfileOperations.class);
-
- final Metadata metadata = new Metadata();
- metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
- metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
-
- profileAnonymousBlockingStub = ProfileAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
- .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
-
- final ProfileAnonymousGrpcService profileAnonymousGrpcService = new ProfileAnonymousGrpcService(
+
+ @Override
+ protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
+ return new ProfileAnonymousGrpcService(
accountsManager,
profilesManager,
profileBadgeConverter,
serverZkProfileOperations
);
+ }
- GRPC_SERVER_EXTENSION.getServiceRegistry()
- .addService(profileAnonymousGrpcService);
+ @Override
+ protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
+ final Metadata metadata = new Metadata();
+ metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
+ metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
+ return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}
@Test
@@ -151,7 +155,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final GetUnversionedProfileResponse response = profileAnonymousBlockingStub.getUnversionedProfile(request);
+ final GetUnversionedProfileResponse response = unauthenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder()
@@ -189,10 +193,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getUnversionedProfile(requestBuilder.build()));
-
- assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getUnversionedProfile(requestBuilder.build()));
}
private static Stream getUnversionedProfileUnauthenticated() {
@@ -242,7 +243,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final GetVersionedProfileResponse response = profileAnonymousBlockingStub.getVersionedProfile(request);
+ final GetVersionedProfileResponse response = unauthenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name))
@@ -287,10 +288,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getVersionedProfile(request));
-
- assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getVersionedProfile(request));
}
@ParameterizedTest
@@ -318,18 +316,15 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getVersionedProfile(requestBuilder.build()));
-
- assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getVersionedProfile(requestBuilder.build()));
}
-
private static Stream getVersionedProfileUnauthenticated() {
return Stream.of(
Arguments.of(true, false),
Arguments.of(false, true)
);
}
+
@Test
void getVersionedProfilePniInvalidArgument() {
final byte[] unidentifiedAccessKey = new byte[16];
@@ -346,10 +341,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getVersionedProfile(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getVersionedProfile(request));
}
@Test
@@ -404,7 +396,7 @@ public class ProfileAnonymousGrpcServiceTest {
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.build();
- final GetExpiringProfileKeyCredentialResponse response = profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request);
+ final GetExpiringProfileKeyCredentialResponse response = unauthenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@@ -442,10 +434,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(requestBuilder.build()));
-
- assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(requestBuilder.build()));
verifyNoInteractions(profilesManager);
}
@@ -483,10 +472,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
-
- assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
@ParameterizedTest
@@ -521,10 +507,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream getExpiringProfileKeyCredentialInvalidArgument() {
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java
index a773c3b60..bf570284b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java
@@ -10,8 +10,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
@@ -21,15 +19,14 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
+import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
import io.grpc.Metadata;
-import io.grpc.ServerInterceptors;
import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
-import io.grpc.stub.MetadataUtils;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Clock;
@@ -44,15 +41,14 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.annotation.Nullable;
-import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType;
@@ -82,7 +78,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
-import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
@@ -100,49 +95,54 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
-import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
+import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
-public class ProfileGrpcServiceTest {
- private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
- private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
+public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest {
+
private static final String S3_BUCKET = "profileBucket";
+
private static final String VERSION = "someVersion";
+
private static final byte[] VALID_NAME = new byte[81];
+
+ @Mock
private AccountsManager accountsManager;
+
+ @Mock
private ProfilesManager profilesManager;
+
+ @Mock
private DynamicPaymentsConfiguration dynamicPaymentsConfiguration;
+
+ @Mock
private S3AsyncClient asyncS3client;
+
+ @Mock
private VersionedProfile profile;
+
+ @Mock
private Account account;
+
+ @Mock
private RateLimiter rateLimiter;
+
+ @Mock
private ProfileBadgeConverter profileBadgeConverter;
+
+ @Mock
private ServerZkProfileOperations serverZkProfileOperations;
- private ProfileGrpc.ProfileBlockingStub profileBlockingStub;
-
- @RegisterExtension
- static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
-
- @BeforeEach
- void setup() {
- accountsManager = mock(AccountsManager.class);
- profilesManager = mock(ProfilesManager.class);
- dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
- asyncS3client = mock(S3AsyncClient.class);
- profile = mock(VersionedProfile.class);
- account = mock(Account.class);
- rateLimiter = mock(RateLimiter.class);
- profileBadgeConverter = mock(ProfileBadgeConverter.class);
- serverZkProfileOperations = mock(ServerZkProfileOperations.class);
+ @Override
+ protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1");
@@ -170,30 +170,6 @@ public class ProfileGrpcServiceTest {
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
- profileBlockingStub = ProfileGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
- .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
-
- final ProfileGrpcService profileGrpcService = new ProfileGrpcService(
- Clock.systemUTC(),
- accountsManager,
- profilesManager,
- dynamicConfigurationManager,
- badgesConfiguration,
- asyncS3client,
- policyGenerator,
- policySigner,
- profileBadgeConverter,
- rateLimiters,
- serverZkProfileOperations,
- S3_BUCKET
- );
-
- final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
- mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
-
- GRPC_SERVER_EXTENSION.getServiceRegistry()
- .addService(ServerInterceptors.intercept(profileGrpcService, mockAuthenticationInterceptor));
-
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@@ -218,6 +194,21 @@ public class ProfileGrpcServiceTest {
when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList());
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
+
+ return new ProfileGrpcService(
+ Clock.systemUTC(),
+ accountsManager,
+ profilesManager,
+ dynamicConfigurationManager,
+ badgesConfiguration,
+ asyncS3client,
+ policyGenerator,
+ policySigner,
+ profileBadgeConverter,
+ rateLimiters,
+ serverZkProfileOperations,
+ S3_BUCKET
+ );
}
@Test
@@ -237,7 +228,7 @@ public class ProfileGrpcServiceTest {
.setCommitment(ByteString.copyFrom(commitment))
.build();
- profileBlockingStub.setProfile(request);
+ authenticatedServiceStub().setProfile(request);
final ArgumentCaptor profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class);
@@ -274,7 +265,7 @@ public class ProfileGrpcServiceTest {
hasPreviousProfile ? Optional.of(profile) : Optional.empty()));
when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
- SetProfileResponse response = profileBlockingStub.setProfile(request);
+ final SetProfileResponse response = authenticatedServiceStub().setProfile(request);
if (expectHasS3UploadPath) {
assertTrue(response.getAttributes().getPath().startsWith("profiles/"));
@@ -312,10 +303,7 @@ public class ProfileGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setProfileInvalidRequestData(final SetProfileRequest request) {
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
-
- assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setProfile(request));
}
private static Stream setProfileInvalidRequestData() throws InvalidInputException{
@@ -386,12 +374,10 @@ public class ProfileGrpcServiceTest {
when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
if (hasExistingPaymentAddress) {
- assertDoesNotThrow(() -> profileBlockingStub.setProfile(request),
+ assertDoesNotThrow(() -> authenticatedServiceStub().setProfile(request),
"Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address");
} else {
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
- assertEquals(Status.PERMISSION_DENIED.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().setProfile(request));
}
}
@@ -433,7 +419,7 @@ public class ProfileGrpcServiceTest {
when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges);
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
- final GetUnversionedProfileResponse response = profileBlockingStub.getUnversionedProfile(request);
+ final GetUnversionedProfileResponse response = authenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder()
@@ -472,10 +458,7 @@ public class ProfileGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getUnversionedProfile(request));
-
- assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getUnversionedProfile(request));
}
@ParameterizedTest
@@ -493,14 +476,7 @@ public class ProfileGrpcServiceTest {
.build())
.build();
- final StatusRuntimeException exception =
- assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.getUnversionedProfile(request));
-
- assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
- assertNotNull(exception.getTrailers());
- assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
-
- verifyNoInteractions(accountsManager);
+ assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getUnversionedProfile(request), accountsManager);
}
@ParameterizedTest
@@ -531,7 +507,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
- final GetVersionedProfileResponse response = profileBlockingStub.getVersionedProfile(request);
+ final GetVersionedProfileResponse response = authenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name))
@@ -545,7 +521,6 @@ public class ProfileGrpcServiceTest {
assertEquals(expectedResponseBuilder.build(), response);
}
-
private static Stream getVersionedProfile() {
return Stream.of(
Arguments.of("version1", "version1", true),
@@ -553,6 +528,7 @@ public class ProfileGrpcServiceTest {
Arguments.of("version1", "version2", false)
);
}
+
@ParameterizedTest
@MethodSource
void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) {
@@ -566,10 +542,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile)));
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getVersionedProfile(request));
-
- assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getVersionedProfile(request));
}
private static Stream getVersionedProfileAccountOrProfileNotFound() {
@@ -581,10 +554,7 @@ public class ProfileGrpcServiceTest {
@Test
void getVersionedProfileRatelimited() {
- final Duration retryAfterDuration = Duration.ofMinutes(7);
-
- when(rateLimiter.validateReactive(any(UUID.class)))
- .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
+ final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(7), false);
final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder()
.setAccountIdentifier(ServiceIdentifier.newBuilder()
@@ -594,15 +564,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getVersionedProfile(request));
-
- assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
- assertNotNull(exception.getTrailers());
- assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
-
- verifyNoInteractions(accountsManager);
- verifyNoInteractions(profilesManager);
+ assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getVersionedProfile(request), accountsManager, profilesManager);
}
@Test
@@ -615,9 +577,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getVersionedProfile(request));
- assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getVersionedProfile(request));
}
@Test
@@ -664,7 +624,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- final GetExpiringProfileKeyCredentialResponse response = profileBlockingStub.getExpiringProfileKeyCredential(request);
+ final GetExpiringProfileKeyCredentialResponse response = authenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@@ -677,9 +637,8 @@ public class ProfileGrpcServiceTest {
@Test
void getExpiringProfileKeyCredentialRateLimited() {
- final Duration retryAfterDuration = Duration.ofMinutes(5);
- when(rateLimiter.validateReactive(AUTHENTICATED_ACI))
- .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
+ final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(
+ rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(5), false);
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder()
@@ -692,14 +651,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getExpiringProfileKeyCredential(request));
-
- assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
- assertNotNull(exception.getTrailers());
- assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
-
- verifyNoInteractions(profilesManager);
+ assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request), profilesManager);
}
@ParameterizedTest
@@ -723,10 +675,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getExpiringProfileKeyCredential(request));
-
- assertEquals(Status.Code.NOT_FOUND, statusRuntimeException.getStatus().getCode());
+ assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream getExpiringProfileKeyCredentialAccountOrProfileNotFound() {
@@ -761,10 +710,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
- StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
- () -> profileBlockingStub.getExpiringProfileKeyCredential(request));
-
- assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
+ assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream getExpiringProfileKeyCredentialInvalidArgument() {
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java
new file mode 100644
index 000000000..8d3907022
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import static java.util.Objects.requireNonNull;
+
+import io.grpc.BindableService;
+import io.grpc.Channel;
+import io.grpc.stub.AbstractBlockingStub;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.UUID;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.mockito.MockitoAnnotations;
+import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
+import org.whispersystems.textsecuregcm.storage.Device;
+
+/**
+ * Base class for the common case of gRPC services tests. This base class makes some assumptions
+ * and introduces some constraints on the implementing classes with a goal of simplifying the process
+ * of creating a test for the most of the gRPC services.
+ *
+ * @param Class of the gRPC service that is being tested.
+ * @param Class of the gRPC service stub.
+ */
+public abstract class SimpleBaseGrpcTest> {
+
+ @RegisterExtension
+ protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_AUTHENTICATED = new GrpcServerExtension();
+
+ @RegisterExtension
+ protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_UNAUTHENTICATED = new GrpcServerExtension();
+
+ protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
+
+ protected static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
+
+ private AutoCloseable mocksCloseable;
+
+ private MockAuthenticationInterceptor mockAuthenticationInterceptor;
+
+ private SERVICE service;
+
+ private STUB authenticatedServiceStub;
+
+ private STUB unauthenticatedServiceStub;
+
+
+ /**
+ * This method is invoked before each test and is expected to create an instance of the gRPC service
+ * that is being tested and also to perform all necessary before-each setup.
+ *
+ * Extending classes may have their own {@code @BeforeEach} method, but it will be called after this method.
+ * @return an instance of the gRPC service.
+ */
+ protected abstract SERVICE createServiceBeforeEachTest();
+
+ /**
+ * The default implementation of this method is based on figuring out the name of the {@code `*Grpc`} class
+ * and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
+ *
+ * Overriding this method can be helpful if addutional configuration of the stub is required, e.g. adding interceptors:
+ *
+ * protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
+ * final Metadata metadata = new Metadata();
+ * metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
+ * metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
+ * return super.createStub(channel)
+ * .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
+ * }
+ *
+ * @param channel grpc channel to create create the stub for.
+ * @return and instance of the service stub.
+ */
+ protected STUB createStub(final Channel channel) throws
+ ClassNotFoundException,
+ NoSuchMethodException,
+ InvocationTargetException,
+ IllegalAccessException {
+ final String serviceClassName = service.bindService().getServiceDescriptor().getName();
+ final String grpcClassName = serviceClassName + "Grpc";
+ final Class> grpcClass = ClassLoader.getSystemClassLoader().loadClass(grpcClassName);
+ final Method newBlockingStubMethod = grpcClass.getMethod("newBlockingStub", Channel.class);
+ final Object stub = newBlockingStubMethod.invoke(null, channel);
+ //noinspection unchecked
+ return (STUB) stub;
+ }
+
+ @BeforeEach
+ protected void baseSetup() {
+ mocksCloseable = MockitoAnnotations.openMocks(this);
+ service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
+ mockAuthenticationInterceptor = GrpcTestUtils.setupAuthenticatedExtension(
+ GRPC_SERVER_EXTENSION_AUTHENTICATED, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
+ GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, service);
+ try {
+ authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
+ unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
+ } catch (Exception e) {
+ throw new RuntimeException("Could not create a stub based on the service name. Try overriding `createStub()` method.");
+ }
+ }
+
+ @AfterEach
+ public void releaseMocks() throws Exception {
+ mocksCloseable.close();
+ }
+
+ public MockAuthenticationInterceptor mockAuthenticationInterceptor() {
+ return mockAuthenticationInterceptor;
+ }
+
+ protected SERVICE service() {
+ return service;
+ }
+
+ protected STUB authenticatedServiceStub() {
+ return authenticatedServiceStub;
+ }
+
+ protected STUB unauthenticatedServiceStub() {
+ return unauthenticatedServiceStub;
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
index 3173f8753..9d90e2405 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
@@ -111,7 +111,7 @@ public class RateLimitersTest {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
- assertEquals(expected, limiter.config());
+ assertEquals(expected, config(limiter));
}
@Test
@@ -131,16 +131,16 @@ public class RateLimitersTest {
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
- assertEquals(initialRateLimiterConfig, limiter.config());
+ assertEquals(initialRateLimiterConfig, config(limiter));
- assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
- assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
+ assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
+ assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
- assertEquals(updatedRateLimiterCongig, limiter.config());
+ assertEquals(updatedRateLimiterCongig, config(limiter));
- assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
- assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
+ assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
+ assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
}
@Test
@@ -161,22 +161,22 @@ public class RateLimitersTest {
// test only default is present
mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id());
- assertEquals(defaultConfig, limiter.config());
+ assertEquals(defaultConfig, config(limiter));
// test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id());
- assertEquals(configForDynamic, limiter.config());
+ assertEquals(configForDynamic, config(limiter));
// test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic);
- assertEquals(configForDynamic, limiter.config());
+ assertEquals(configForDynamic, config(limiter));
// test static, but no dynamic
mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic);
- assertEquals(configForStatic, limiter.config());
+ assertEquals(configForStatic, config(limiter));
}
private record TestDescriptor(String id) implements RateLimiterDescriptor {
@@ -191,4 +191,14 @@ public class RateLimitersTest {
return new RateLimiterConfig(1, Duration.ofMinutes(1));
}
}
+
+ private static RateLimiterConfig config(final RateLimiter rateLimiter) {
+ if (rateLimiter instanceof StaticRateLimiter rm) {
+ return rm.config();
+ }
+ if (rateLimiter instanceof DynamicRateLimiter rm) {
+ return rm.config();
+ }
+ throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName());
+ }
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java
index 16466b8fd..af5a62b3b 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java
@@ -12,12 +12,14 @@ import static org.mockito.Mockito.doThrow;
import java.time.Duration;
import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomUtils;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
+import reactor.core.publisher.Mono;
public final class MockUtils {
@@ -46,32 +48,80 @@ public final class MockUtils {
}
public static void updateRateLimiterResponseToAllow(
- final RateLimiters rateLimitersMock,
- final RateLimiters.For handle,
+ final RateLimiter mockRateLimiter,
final String input) {
- final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
- doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doNothing().when(mockRateLimiter).validate(eq(input));
+ doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
+ doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
+ } catch (final RateLimitExceededException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static void updateRateLimiterResponseToAllow(
+ final RateLimiter mockRateLimiter,
+ final UUID input) {
+ try {
+ doNothing().when(mockRateLimiter).validate(eq(input));
+ doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
+ doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
+ public static void updateRateLimiterResponseToAllow(
+ final RateLimiters rateLimitersMock,
+ final RateLimiters.For handle,
+ final String input) {
+ final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
+ doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
+ updateRateLimiterResponseToAllow(mockRateLimiter, input);
+ }
+
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
+ updateRateLimiterResponseToAllow(mockRateLimiter, input);
+ }
+
+ public static Duration updateRateLimiterResponseToFail(
+ final RateLimiter mockRateLimiter,
+ final String input,
+ final Duration retryAfter,
+ final boolean legacyStatusCode) {
try {
- doNothing().when(mockRateLimiter).validate(eq(input));
+ final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
+ doThrow(exception).when(mockRateLimiter).validate(eq(input));
+ doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
+ doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
+ return retryAfter;
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
- public static void updateRateLimiterResponseToFail(
+ public static Duration updateRateLimiterResponseToFail(
+ final RateLimiter mockRateLimiter,
+ final UUID input,
+ final Duration retryAfter,
+ final boolean legacyStatusCode) {
+ try {
+ final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
+ doThrow(exception).when(mockRateLimiter).validate(eq(input));
+ doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
+ doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
+ return retryAfter;
+ } catch (final RateLimitExceededException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final String input,
@@ -79,14 +129,10 @@ public final class MockUtils {
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
- try {
- doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
- } catch (final RateLimitExceededException e) {
- throw new RuntimeException(e);
- }
+ return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
}
- public static void updateRateLimiterResponseToFail(
+ public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input,
@@ -94,11 +140,7 @@ public final class MockUtils {
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
- try {
- doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
- } catch (final RateLimitExceededException e) {
- throw new RuntimeException(e);
- }
+ return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
}
public static SecretBytes randomSecretBytes(final int size) {