diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 080d03819..1b7185a8f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -120,6 +120,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor; +import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor; import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper; import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; @@ -644,8 +645,6 @@ public class WhisperServerService extends Application grpcServer = ServerBuilder.forPort(config.getGrpcPort()) - // TODO: specialize metrics with user-agent platform - .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry)) .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor)) .addService(new KeysAnonymousGrpcService(accountsManager, keys)) .addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, @@ -657,13 +656,16 @@ public class WhisperServerService extends Application RETRY_AFTER_DURATION_KEY = + Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() { + @Override + public String toAsciiString(final Duration value) { + return value.toString(); + } + + @Override + public Duration parseAsciiString(final String serialized) { + return Duration.parse(serialized); + } + }); @Nullable private final Duration retryDuration; @@ -33,4 +49,19 @@ public class RateLimitExceededException extends Exception { public boolean isLegacy() { return legacy; } + + @Override + public Status grpcStatus() { + return Status.RESOURCE_EXHAUSTED; + } + + @Override + public Optional grpcMetadata() { + return getRetryDuration() + .map(duration -> { + final Metadata metadata = new Metadata(); + metadata.put(RETRY_AFTER_DURATION_KEY, duration); + return metadata; + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java index f981f094d..7acab7547 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java @@ -24,11 +24,6 @@ public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase { this.rateLimiters = rateLimiters; } - @Override - protected Throwable onErrorMap(final Throwable throwable) { - return RateLimitUtil.mapRateLimitExceededException(throwable); - } - @Override public Mono getTurnCredentials(final GetTurnCredentialsRequest request) { final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java new file mode 100644 index 000000000..ce0cc00fb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ConvertibleToGrpcStatus.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Metadata; +import io.grpc.Status; +import java.util.Optional; + +/** + * Interface to be imlemented by our custom exceptions that are consistently mapped to a gRPC status. + */ +public interface ConvertibleToGrpcStatus { + + Status grpcStatus(); + + Optional grpcMetadata(); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java new file mode 100644 index 000000000..75d25d498 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ErrorMappingInterceptor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.ForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; + +/** + * This interceptor observes responses from the service and if the response status is {@link Status#UNKNOWN} + * and there is a non-null cause which is an instance of {@link ConvertibleToGrpcStatus}, + * then status and metadata to be returned to the client is resolved from that object. + *

+ * This eliminates the need of having each service to override {@code `onErrorMap()`} method for commonly used exceptions. + */ +public class ErrorMappingInterceptor implements ServerInterceptor { + + @Override + public ServerCall.Listener interceptCall( + final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) { + @Override + public void close(final Status status, final Metadata trailers) { + // The idea is to only apply the automatic conversion logic in the cases + // when there was no explicit decision by the service to provide a status. + // I.e. if at this point we see anything but the `UNKNOWN`, + // that means that some logic in the service made this decision already + // and automatic conversion may conflict with it. + if (status.getCode().equals(Status.Code.UNKNOWN) + && status.getCause() instanceof ConvertibleToGrpcStatus convertibleToGrpcStatus) { + super.close( + convertibleToGrpcStatus.grpcStatus(), + convertibleToGrpcStatus.grpcMetadata().orElseGet(Metadata::new) + ); + } else { + super.close(status, trailers); + } + } + }, headers); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index 9ced01fbf..6d3c3ccf9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -74,11 +74,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { this.rateLimiters = rateLimiters; } - @Override - protected Throwable onErrorMap(final Throwable throwable) { - return RateLimitUtil.mapRateLimitExceededException(throwable); - } - @Override public Mono getPreKeyCount(final GetPreKeyCountRequest request) { return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java index e9dd35e46..7c26a4582 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcService.java @@ -100,11 +100,6 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase { this.bucket = bucket; } - @Override - protected Throwable onErrorMap(final Throwable throwable) { - return RateLimitUtil.mapRateLimitExceededException(throwable); - } - @Override public Mono setProfile(final SetProfileRequest request) { validateRequest(request); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java deleted file mode 100644 index dadfd5417..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Metadata; -import io.grpc.Status; -import io.grpc.StatusException; -import java.time.Duration; -import javax.annotation.Nullable; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; - -public class RateLimitUtil { - - public static final Metadata.Key RETRY_AFTER_DURATION_KEY = - Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() { - @Override - public String toAsciiString(final Duration value) { - return value.toString(); - } - - @Override - public Duration parseAsciiString(final String serialized) { - return Duration.parse(serialized); - } - }); - - public static Throwable mapRateLimitExceededException(final Throwable throwable) { - if (throwable instanceof RateLimitExceededException rateLimitExceededException) { - @Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration() - .map(duration -> { - final Metadata metadata = new Metadata(); - metadata.put(RETRY_AFTER_DURATION_KEY, duration); - - return metadata; - }).orElse(null); - - return new StatusException(Status.RESOURCE_EXHAUSTED, trailers); - } - - return throwable; - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java index 0753dfadf..3f75c7eb9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java @@ -6,66 +6,41 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; -import io.grpc.ServerInterceptors; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; import java.time.Duration; import java.util.List; -import java.util.UUID; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.Mock; import org.signal.chat.calling.CallingGrpc; import org.signal.chat.calling.GetTurnCredentialsRequest; import org.signal.chat.calling.GetTurnCredentialsResponse; import org.whispersystems.textsecuregcm.auth.TurnToken; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; -import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.storage.Device; -import reactor.core.publisher.Mono; +import org.whispersystems.textsecuregcm.util.MockUtils; -class CallingGrpcServiceTest { +class CallingGrpcServiceTest extends SimpleBaseGrpcTest { + @Mock private TurnTokenGenerator turnTokenGenerator; + + @Mock private RateLimiter turnCredentialRateLimiter; - private CallingGrpc.CallingBlockingStub callingStub; - - private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); - private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; - - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); - - @BeforeEach - void setUp() { - turnTokenGenerator = mock(TurnTokenGenerator.class); - turnCredentialRateLimiter = mock(RateLimiter.class); + @Override + protected CallingGrpcService createServiceBeforeEachTest() { final RateLimiters rateLimiters = mock(RateLimiters.class); when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter); - - final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters); - - final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); - - GRPC_SERVER_EXTENSION.getServiceRegistry() - .addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor)); - - callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); + return new CallingGrpcService(turnTokenGenerator, rateLimiters); } @Test @@ -74,10 +49,10 @@ class CallingGrpcServiceTest { final String password = "test-password"; final List urls = List.of("first", "second"); - when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty()); + MockUtils.updateRateLimiterResponseToAllow(turnCredentialRateLimiter, AUTHENTICATED_ACI); when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls)); - final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()); + final GetTurnCredentialsResponse response = authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()); final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder() .setUsername(username) @@ -90,20 +65,10 @@ class CallingGrpcServiceTest { @Test void getTurnCredentialsRateLimited() { - final Duration retryAfter = Duration.ofMinutes(19); - - when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)) - .thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false))); - - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build())); - + final Duration retryAfter = MockUtils.updateRateLimiterResponseToFail( + turnCredentialRateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(19), false); + assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build())); verify(turnTokenGenerator, never()).generate(any()); - - assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); - assertNotNull(exception.getTrailers()); - assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); - verifyNoInteractions(turnTokenGenerator); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java index 4effc71f5..b5a5e673c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java @@ -6,7 +6,6 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.mock; @@ -14,11 +13,10 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.protobuf.ByteString; -import io.grpc.ServerInterceptors; import io.grpc.Status; -import io.grpc.StatusRuntimeException; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; @@ -26,21 +24,19 @@ import java.time.temporal.ChronoUnit; import java.util.Base64; import java.util.List; import java.util.Optional; -import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Consumer; import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.mockito.Mock; import org.signal.chat.device.ClearPushTokenRequest; import org.signal.chat.device.ClearPushTokenResponse; import org.signal.chat.device.DevicesGrpc; @@ -54,42 +50,31 @@ import org.signal.chat.device.SetDeviceNameRequest; import org.signal.chat.device.SetDeviceNameResponse; import org.signal.chat.device.SetPushTokenRequest; import org.signal.chat.device.SetPushTokenResponse; -import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; -class DevicesGrpcServiceTest { +class DevicesGrpcServiceTest extends SimpleBaseGrpcTest { + @Mock private AccountsManager accountsManager; + + @Mock private KeysManager keysManager; + + @Mock private MessagesManager messagesManager; + @Mock private Account authenticatedAccount; - private MockAuthenticationInterceptor mockAuthenticationInterceptor; - private DevicesGrpc.DevicesBlockingStub devicesStub; - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); - - private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); - private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; - - @BeforeEach - void setUp() { - accountsManager = mock(AccountsManager.class); - keysManager = mock(KeysManager.class); - messagesManager = mock(MessagesManager.class); - - authenticatedAccount = mock(Account.class); + @Override + protected DevicesGrpcService createServiceBeforeEachTest() { when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI); - mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); - when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)) .thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); @@ -117,11 +102,7 @@ class DevicesGrpcServiceTest { when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); - final DevicesGrpcService devicesGrpcService = new DevicesGrpcService(accountsManager, keysManager, messagesManager); - devicesStub = DevicesGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); - - GRPC_SERVER_EXTENSION.getServiceRegistry() - .addService(ServerInterceptors.intercept(devicesGrpcService, mockAuthenticationInterceptor)); + return new DevicesGrpcService(accountsManager, keysManager, messagesManager); } @Test @@ -161,14 +142,14 @@ class DevicesGrpcServiceTest { .build()) .build(); - assertEquals(expectedResponse, devicesStub.getDevices(GetDevicesRequest.newBuilder().build())); + assertEquals(expectedResponse, authenticatedServiceStub().getDevices(GetDevicesRequest.newBuilder().build())); } @Test void removeDevice() { final long deviceId = 17; - final RemoveDeviceResponse ignored = devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() + final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder() .setId(deviceId) .build()); @@ -179,30 +160,23 @@ class DevicesGrpcServiceTest { @Test void removeDevicePrimary() { - final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() - .setId(1) - .build())); - - assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder() + .setId(1) + .build())); } @Test void removeDeviceNonPrimaryAuthenticated() { - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1); - - final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() - .setId(17) - .build())); - - assertEquals(Status.Code.PERMISSION_DENIED, exception.getStatus().getCode()); + mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1); + assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder() + .setId(17) + .build())); } @ParameterizedTest @ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1}) void setDeviceName(final long deviceId) { - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); + mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); final Device device = mock(Device.class); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); @@ -210,7 +184,7 @@ class DevicesGrpcServiceTest { final byte[] deviceName = new byte[128]; ThreadLocalRandom.current().nextBytes(deviceName); - final SetDeviceNameResponse ignored = devicesStub.setDeviceName(SetDeviceNameRequest.newBuilder() + final SetDeviceNameResponse ignored = authenticatedServiceStub().setDeviceName(SetDeviceNameRequest.newBuilder() .setName(ByteString.copyFrom(deviceName)) .build()); @@ -221,11 +195,7 @@ class DevicesGrpcServiceTest { @MethodSource void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) { when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class))); - - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> devicesStub.setDeviceName(request)); - - assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setDeviceName(request)); } private static Stream setDeviceNameIllegalArgument() { @@ -248,12 +218,12 @@ class DevicesGrpcServiceTest { @Nullable final String expectedApnsVoipToken, @Nullable final String expectedFcmToken) { - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); + mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); final Device device = mock(Device.class); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); - final SetPushTokenResponse ignored = devicesStub.setPushToken(request); + final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request); verify(device).setApnId(expectedApnsToken); verify(device).setVoipApnId(expectedApnsVoipToken); @@ -312,7 +282,7 @@ class DevicesGrpcServiceTest { when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device)); - final SetPushTokenResponse ignored = devicesStub.setPushToken(request); + final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request); verify(accountsManager, never()).updateDevice(any(), anyLong(), any()); } @@ -352,12 +322,7 @@ class DevicesGrpcServiceTest { void setPushTokenIllegalArgument(final SetPushTokenRequest request) { final Device device = mock(Device.class); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device)); - - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> devicesStub.setPushToken(request)); - - assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode()); - + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request)); verify(accountsManager, never()).updateDevice(any(), anyLong(), any()); } @@ -383,7 +348,7 @@ class DevicesGrpcServiceTest { @Nullable final String fcmToken, @Nullable final String expectedUserAgent) { - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); + mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); final Device device = mock(Device.class); when(device.getId()).thenReturn(deviceId); @@ -393,7 +358,7 @@ class DevicesGrpcServiceTest { when(device.getGcmId()).thenReturn(fcmToken); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); - final ClearPushTokenResponse ignored = devicesStub.clearPushToken(ClearPushTokenRequest.newBuilder().build()); + final ClearPushTokenResponse ignored = authenticatedServiceStub().clearPushToken(ClearPushTokenRequest.newBuilder().build()); verify(device).setApnId(null); verify(device).setVoipApnId(null); @@ -430,12 +395,12 @@ class DevicesGrpcServiceTest { @CartesianTest.Values(booleans = {true, false}) final boolean pni, @CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) { - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); + mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); final Device device = mock(Device.class); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); - final SetCapabilitiesResponse ignored = devicesStub.setCapabilities(SetCapabilitiesRequest.newBuilder() + final SetCapabilitiesResponse ignored = authenticatedServiceStub().setCapabilities(SetCapabilitiesRequest.newBuilder() .setStorage(storage) .setTransfer(transfer) .setPni(pni) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java new file mode 100644 index 000000000..116aafc13 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcTestUtils.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.verifyNoInteractions; + +import io.grpc.BindableService; +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.function.Executable; +import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; + +public final class GrpcTestUtils { + + private GrpcTestUtils() { + // noop + } + + public static MockAuthenticationInterceptor setupAuthenticatedExtension( + final GrpcServerExtension extension, + final UUID authenticatedAci, + final long authenticatedDeviceId, + final BindableService service) { + final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); + mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId); + extension.getServiceRegistry() + .addService(ServerInterceptors.intercept(service, mockAuthenticationInterceptor, new ErrorMappingInterceptor())); + return mockAuthenticationInterceptor; + } + + public static void setupUnauthenticatedExtension( + final GrpcServerExtension extension, + final BindableService service) { + extension.getServiceRegistry() + .addService(ServerInterceptors.intercept(service, new ErrorMappingInterceptor())); + } + + public static void assertStatusException(final Status expected, final Executable serviceCall) { + final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall); + assertEquals(expected.getCode(), exception.getStatus().getCode()); + } + + public static void assertRateLimitExceeded( + final Duration expectedRetryAfter, + final Executable serviceCall, + final Object... mocksToCheckForNoInteraction) { + final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall); + assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus()); + assertNotNull(exception.getTrailers()); + assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY)); + for (final Object mock: mocksToCheckForNoInteraction) { + verifyNoInteractions(mock); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index 35064f3be..721081c92 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.protobuf.ByteString; import io.grpc.Status; @@ -20,11 +21,10 @@ import java.util.Collections; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.KemSignedPreKey; @@ -48,27 +48,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.UUIDUtil; -class KeysAnonymousGrpcServiceTest { +class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest { + @Mock private AccountsManager accountsManager; + + @Mock private KeysManager keysManager; - private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub; - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); - - @BeforeEach - void setUp() { - accountsManager = mock(AccountsManager.class); - keysManager = mock(KeysManager.class); - - final KeysAnonymousGrpcService keysGrpcService = - new KeysAnonymousGrpcService(accountsManager, keysManager); - - keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); - - GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService); + @Override + protected KeysAnonymousGrpcService createServiceBeforeEachTest() { + return new KeysAnonymousGrpcService(accountsManager, keysManager); } @Test @@ -101,7 +92,7 @@ class KeysAnonymousGrpcServiceTest { when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey))); when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey); - final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder() .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) .setRequest(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder() @@ -152,19 +143,15 @@ class KeysAnonymousGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier))) .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); - final StatusRuntimeException statusRuntimeException = - assertThrows(StatusRuntimeException.class, - () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() - .setRequest(GetPreKeysRequest.newBuilder() - .setTargetIdentifier(ServiceIdentifier.newBuilder() - .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(identifier)) - .build()) - .setDeviceId(Device.MASTER_ID) - .build()) - .build())); - - assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setRequest(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(identifier)) + .build()) + .setDeviceId(Device.MASTER_ID) + .build()) + .build())); } @Test @@ -174,7 +161,7 @@ class KeysAnonymousGrpcServiceTest { final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder() .setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID())) .setRequest(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder() @@ -205,19 +192,15 @@ class KeysAnonymousGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier))) .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, - () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() - .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) - .setRequest(GetPreKeysRequest.newBuilder() - .setTargetIdentifier(ServiceIdentifier.newBuilder() - .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(accountIdentifier)) - .build()) - .setDeviceId(deviceId) - .build()) - .build())); - - assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) + .setRequest(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(accountIdentifier)) + .build()) + .setDeviceId(deviceId) + .build()) + .build())); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java index 1ae1a443f..b0659107f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java @@ -6,7 +6,6 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; @@ -16,9 +15,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.protobuf.ByteString; -import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; import java.time.Duration; @@ -32,14 +32,13 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.KemSignedPreKey; @@ -56,7 +55,6 @@ import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; @@ -73,41 +71,34 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.core.publisher.Mono; -class KeysGrpcServiceTest { - - private AccountsManager accountsManager; - private KeysManager keysManager; - private RateLimiter preKeysRateLimiter; - - private Device authenticatedDevice; - - private KeysGrpc.KeysBlockingStub keysStub; - - private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); - private static final UUID AUTHENTICATED_PNI = UUID.randomUUID(); - private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; +class KeysGrpcServiceTest extends SimpleBaseGrpcTest { private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); + protected static final UUID AUTHENTICATED_PNI = UUID.randomUUID(); - @BeforeEach - void setUp() { - accountsManager = mock(AccountsManager.class); - keysManager = mock(KeysManager.class); - preKeysRateLimiter = mock(RateLimiter.class); + @Mock + private AccountsManager accountsManager; + @Mock + private KeysManager keysManager; + + @Mock + private RateLimiter preKeysRateLimiter; + + @Mock + private Device authenticatedDevice; + + + @Override + protected KeysGrpcService createServiceBeforeEachTest() { final RateLimiters rateLimiters = mock(RateLimiters.class); when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter); when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); - final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters); - keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); - - authenticatedDevice = mock(Device.class); when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID); final Account authenticatedAccount = mock(Account.class); @@ -119,17 +110,13 @@ class KeysGrpcServiceTest { when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey())); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice)); - final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); - - GRPC_SERVER_EXTENSION.getServiceRegistry() - .addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor)); - when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount)); when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount)); when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); + + return new KeysGrpcService(accountsManager, keysManager, rateLimiters); } @Test @@ -152,7 +139,7 @@ class KeysGrpcServiceTest { .setPniEcPreKeyCount(3) .setPniKemPreKeyCount(4) .build(), - keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build())); + authenticatedServiceStub().getPreKeyCount(GetPreKeyCountRequest.newBuilder().build())); } @ParameterizedTest @@ -168,7 +155,7 @@ class KeysGrpcServiceTest { .thenReturn(CompletableFuture.completedFuture(null)); //noinspection ResultOfMethodCallIgnored - keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder() + authenticatedServiceStub().setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder() .setIdentityType(identityType) .addAllPreKeys(preKeys.stream() .map(preKey -> EcPreKey.newBuilder() @@ -189,10 +176,7 @@ class KeysGrpcServiceTest { @ParameterizedTest @MethodSource void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeEcPreKeys(request)); } private static Stream setOneTimeEcPreKeysWithError() { @@ -242,7 +226,7 @@ class KeysGrpcServiceTest { .thenReturn(CompletableFuture.completedFuture(null)); //noinspection ResultOfMethodCallIgnored - keysStub.setOneTimeKemSignedPreKeys( + authenticatedServiceStub().setOneTimeKemSignedPreKeys( SetOneTimeKemSignedPreKeysRequest.newBuilder() .setIdentityType(identityType) .addAllPreKeys(preKeys.stream() @@ -265,10 +249,7 @@ class KeysGrpcServiceTest { @ParameterizedTest @MethodSource void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeKemSignedPreKeys(request)); } private static Stream setOneTimeKemSignedPreKeysWithError() { @@ -333,7 +314,7 @@ class KeysGrpcServiceTest { final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair); //noinspection ResultOfMethodCallIgnored - keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder() + authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder() .setIdentityType(identityType) .setSignedPreKey(EcSignedPreKey.newBuilder() .setKeyId(signedPreKey.keyId()) @@ -359,7 +340,7 @@ class KeysGrpcServiceTest { @MethodSource void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) { final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request)); + assertThrows(StatusRuntimeException.class, () -> authenticatedServiceStub().setEcSignedPreKey(request)); assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); } @@ -416,7 +397,7 @@ class KeysGrpcServiceTest { final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair); //noinspection ResultOfMethodCallIgnored - keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder() + authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder() .setIdentityType(identityType) .setSignedPreKey(KemSignedPreKey.newBuilder() .setKeyId(lastResortPreKey.keyId()) @@ -437,10 +418,7 @@ class KeysGrpcServiceTest { @ParameterizedTest @MethodSource void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setKemLastResortPreKey(request)); } private static Stream setLastResortPreKeyWithError() { @@ -528,7 +506,7 @@ class KeysGrpcServiceTest { .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); { - final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(grpcIdentityType) .setUuid(UUIDUtil.toByteString(identifier)) @@ -563,7 +541,7 @@ class KeysGrpcServiceTest { when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); { - final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(grpcIdentityType) .setUuid(UUIDUtil.toByteString(identifier)) @@ -606,15 +584,12 @@ class KeysGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() - .setTargetIdentifier(ServiceIdentifier.newBuilder() - .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) - .build()) - .build())); - - assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build())); } @ParameterizedTest @@ -631,16 +606,13 @@ class KeysGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier))) .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() - .setTargetIdentifier(ServiceIdentifier.newBuilder() - .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(accountIdentifier)) - .build()) - .setDeviceId(deviceId) - .build())); - - assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(accountIdentifier)) + .build()) + .setDeviceId(deviceId) + .build())); } @Test @@ -655,22 +627,15 @@ class KeysGrpcServiceTest { .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); final Duration retryAfterDuration = Duration.ofMinutes(7); - when(preKeysRateLimiter.validateReactive(anyString())) .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false))); - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() - .setTargetIdentifier(ServiceIdentifier.newBuilder() - .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) - .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) - .build()) - .build())); - - assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); - assertNotNull(exception.getTrailers()); - assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); - + assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build())); verifyNoInteractions(accountsManager); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java index c129ac8df..65ef88fe9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcServiceTest.java @@ -1,19 +1,27 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + package org.whispersystems.textsecuregcm.grpc; import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.protobuf.ByteString; +import io.grpc.Channel; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.stub.MetadataUtils; +import java.lang.reflect.InvocationTargetException; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; import java.time.Instant; @@ -24,14 +32,12 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -import io.grpc.StatusRuntimeException; -import io.grpc.stub.MetadataUtils; -import org.junit.jupiter.api.BeforeEach; +import javax.annotation.Nullable; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; import org.signal.chat.common.IdentityType; import org.signal.chat.common.ServiceIdentifier; import org.signal.chat.profile.CredentialType; @@ -72,43 +78,41 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.util.UUIDUtil; -import javax.annotation.Nullable; -public class ProfileAnonymousGrpcServiceTest { +public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest { + + @Mock private Account account; + + @Mock private AccountsManager accountsManager; + + @Mock private ProfilesManager profilesManager; + + @Mock private ProfileBadgeConverter profileBadgeConverter; - private ProfileAnonymousGrpc.ProfileAnonymousBlockingStub profileAnonymousBlockingStub; + + @Mock private ServerZkProfileOperations serverZkProfileOperations; - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); - - @BeforeEach - void setup() { - account = mock(Account.class); - accountsManager = mock(AccountsManager.class); - profilesManager = mock(ProfilesManager.class); - profileBadgeConverter = mock(ProfileBadgeConverter.class); - serverZkProfileOperations = mock(ServerZkProfileOperations.class); - - final Metadata metadata = new Metadata(); - metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us"); - metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3"); - - profileAnonymousBlockingStub = ProfileAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); - - final ProfileAnonymousGrpcService profileAnonymousGrpcService = new ProfileAnonymousGrpcService( + + @Override + protected ProfileAnonymousGrpcService createServiceBeforeEachTest() { + return new ProfileAnonymousGrpcService( accountsManager, profilesManager, profileBadgeConverter, serverZkProfileOperations ); + } - GRPC_SERVER_EXTENSION.getServiceRegistry() - .addService(profileAnonymousGrpcService); + @Override + protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException { + final Metadata metadata = new Metadata(); + metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us"); + metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3"); + return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); } @Test @@ -151,7 +155,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final GetUnversionedProfileResponse response = profileAnonymousBlockingStub.getUnversionedProfile(request); + final GetUnversionedProfileResponse response = unauthenticatedServiceStub().getUnversionedProfile(request); final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey); final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder() @@ -189,10 +193,7 @@ public class ProfileAnonymousGrpcServiceTest { requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); } - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getUnversionedProfile(requestBuilder.build())); - - assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getUnversionedProfile(requestBuilder.build())); } private static Stream getUnversionedProfileUnauthenticated() { @@ -242,7 +243,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final GetVersionedProfileResponse response = profileAnonymousBlockingStub.getVersionedProfile(request); + final GetVersionedProfileResponse response = unauthenticatedServiceStub().getVersionedProfile(request); final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder() .setName(ByteString.copyFrom(name)) @@ -287,10 +288,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getVersionedProfile(request)); - - assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getVersionedProfile(request)); } @ParameterizedTest @@ -318,18 +316,15 @@ public class ProfileAnonymousGrpcServiceTest { requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); } - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getVersionedProfile(requestBuilder.build())); - - assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getVersionedProfile(requestBuilder.build())); } - private static Stream getVersionedProfileUnauthenticated() { return Stream.of( Arguments.of(true, false), Arguments.of(false, true) ); } + @Test void getVersionedProfilePniInvalidArgument() { final byte[] unidentifiedAccessKey = new byte[16]; @@ -346,10 +341,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getVersionedProfile(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getVersionedProfile(request)); } @Test @@ -404,7 +396,7 @@ public class ProfileAnonymousGrpcServiceTest { .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) .build(); - final GetExpiringProfileKeyCredentialResponse response = profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request); + final GetExpiringProfileKeyCredentialResponse response = unauthenticatedServiceStub().getExpiringProfileKeyCredential(request); assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray()); @@ -442,10 +434,7 @@ public class ProfileAnonymousGrpcServiceTest { requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); } - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(requestBuilder.build())); - - assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(requestBuilder.build())); verifyNoInteractions(profilesManager); } @@ -483,10 +472,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request)); - - assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request)); } @ParameterizedTest @@ -521,10 +507,7 @@ public class ProfileAnonymousGrpcServiceTest { .build()) .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request)); } private static Stream getExpiringProfileKeyCredentialInvalidArgument() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java index a773c3b60..bf570284b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ProfileGrpcServiceTest.java @@ -10,8 +10,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -21,15 +19,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.Phonenumber; import com.google.protobuf.ByteString; import io.grpc.Metadata; -import io.grpc.ServerInterceptors; import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.stub.MetadataUtils; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; import java.time.Clock; @@ -44,15 +41,14 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import javax.annotation.Nullable; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.mockito.Mock; import org.signal.chat.common.IdentityType; import org.signal.chat.common.ServiceIdentifier; import org.signal.chat.profile.CredentialType; @@ -82,7 +78,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest; import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext; import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum; -import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter; import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; @@ -100,49 +95,54 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; +import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.core.publisher.Mono; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; -public class ProfileGrpcServiceTest { - private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); - private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; +public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest { + private static final String S3_BUCKET = "profileBucket"; + private static final String VERSION = "someVersion"; + private static final byte[] VALID_NAME = new byte[81]; + + @Mock private AccountsManager accountsManager; + + @Mock private ProfilesManager profilesManager; + + @Mock private DynamicPaymentsConfiguration dynamicPaymentsConfiguration; + + @Mock private S3AsyncClient asyncS3client; + + @Mock private VersionedProfile profile; + + @Mock private Account account; + + @Mock private RateLimiter rateLimiter; + + @Mock private ProfileBadgeConverter profileBadgeConverter; + + @Mock private ServerZkProfileOperations serverZkProfileOperations; - private ProfileGrpc.ProfileBlockingStub profileBlockingStub; - - @RegisterExtension - static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); - - @BeforeEach - void setup() { - accountsManager = mock(AccountsManager.class); - profilesManager = mock(ProfilesManager.class); - dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class); - asyncS3client = mock(S3AsyncClient.class); - profile = mock(VersionedProfile.class); - account = mock(Account.class); - rateLimiter = mock(RateLimiter.class); - profileBadgeConverter = mock(ProfileBadgeConverter.class); - serverZkProfileOperations = mock(ServerZkProfileOperations.class); + @Override + protected ProfileGrpcService createServiceBeforeEachTest() { @SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1"); @@ -170,30 +170,6 @@ public class ProfileGrpcServiceTest { metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us"); metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3"); - profileBlockingStub = ProfileGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); - - final ProfileGrpcService profileGrpcService = new ProfileGrpcService( - Clock.systemUTC(), - accountsManager, - profilesManager, - dynamicConfigurationManager, - badgesConfiguration, - asyncS3client, - policyGenerator, - policySigner, - profileBadgeConverter, - rateLimiters, - serverZkProfileOperations, - S3_BUCKET - ); - - final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); - mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); - - GRPC_SERVER_EXTENSION.getServiceRegistry() - .addService(ServerInterceptors.intercept(profileGrpcService, mockAuthenticationInterceptor)); - when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); @@ -218,6 +194,21 @@ public class ProfileGrpcServiceTest { when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList()); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); + + return new ProfileGrpcService( + Clock.systemUTC(), + accountsManager, + profilesManager, + dynamicConfigurationManager, + badgesConfiguration, + asyncS3client, + policyGenerator, + policySigner, + profileBadgeConverter, + rateLimiters, + serverZkProfileOperations, + S3_BUCKET + ); } @Test @@ -237,7 +228,7 @@ public class ProfileGrpcServiceTest { .setCommitment(ByteString.copyFrom(commitment)) .build(); - profileBlockingStub.setProfile(request); + authenticatedServiceStub().setProfile(request); final ArgumentCaptor profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class); @@ -274,7 +265,7 @@ public class ProfileGrpcServiceTest { hasPreviousProfile ? Optional.of(profile) : Optional.empty())); when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - SetProfileResponse response = profileBlockingStub.setProfile(request); + final SetProfileResponse response = authenticatedServiceStub().setProfile(request); if (expectHasS3UploadPath) { assertTrue(response.getAttributes().getPath().startsWith("profiles/")); @@ -312,10 +303,7 @@ public class ProfileGrpcServiceTest { @ParameterizedTest @MethodSource void setProfileInvalidRequestData(final SetProfileRequest request) { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request)); - - assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setProfile(request)); } private static Stream setProfileInvalidRequestData() throws InvalidInputException{ @@ -386,12 +374,10 @@ public class ProfileGrpcServiceTest { when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); if (hasExistingPaymentAddress) { - assertDoesNotThrow(() -> profileBlockingStub.setProfile(request), + assertDoesNotThrow(() -> authenticatedServiceStub().setProfile(request), "Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address"); } else { - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request)); - assertEquals(Status.PERMISSION_DENIED.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().setProfile(request)); } } @@ -433,7 +419,7 @@ public class ProfileGrpcServiceTest { when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - final GetUnversionedProfileResponse response = profileBlockingStub.getUnversionedProfile(request); + final GetUnversionedProfileResponse response = authenticatedServiceStub().getUnversionedProfile(request); final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey); final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder() @@ -472,10 +458,7 @@ public class ProfileGrpcServiceTest { .build()) .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getUnversionedProfile(request)); - - assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getUnversionedProfile(request)); } @ParameterizedTest @@ -493,14 +476,7 @@ public class ProfileGrpcServiceTest { .build()) .build(); - final StatusRuntimeException exception = - assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.getUnversionedProfile(request)); - - assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); - assertNotNull(exception.getTrailers()); - assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); - - verifyNoInteractions(accountsManager); + assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getUnversionedProfile(request), accountsManager); } @ParameterizedTest @@ -531,7 +507,7 @@ public class ProfileGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); - final GetVersionedProfileResponse response = profileBlockingStub.getVersionedProfile(request); + final GetVersionedProfileResponse response = authenticatedServiceStub().getVersionedProfile(request); final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder() .setName(ByteString.copyFrom(name)) @@ -545,7 +521,6 @@ public class ProfileGrpcServiceTest { assertEquals(expectedResponseBuilder.build(), response); } - private static Stream getVersionedProfile() { return Stream.of( Arguments.of("version1", "version1", true), @@ -553,6 +528,7 @@ public class ProfileGrpcServiceTest { Arguments.of("version1", "version2", false) ); } + @ParameterizedTest @MethodSource void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) { @@ -566,10 +542,7 @@ public class ProfileGrpcServiceTest { when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account))); when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile))); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getVersionedProfile(request)); - - assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getVersionedProfile(request)); } private static Stream getVersionedProfileAccountOrProfileNotFound() { @@ -581,10 +554,7 @@ public class ProfileGrpcServiceTest { @Test void getVersionedProfileRatelimited() { - final Duration retryAfterDuration = Duration.ofMinutes(7); - - when(rateLimiter.validateReactive(any(UUID.class))) - .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false))); + final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(7), false); final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder() .setAccountIdentifier(ServiceIdentifier.newBuilder() @@ -594,15 +564,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getVersionedProfile(request)); - - assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); - assertNotNull(exception.getTrailers()); - assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); - - verifyNoInteractions(accountsManager); - verifyNoInteractions(profilesManager); + assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getVersionedProfile(request), accountsManager, profilesManager); } @Test @@ -615,9 +577,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getVersionedProfile(request)); - assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getVersionedProfile(request)); } @Test @@ -664,7 +624,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - final GetExpiringProfileKeyCredentialResponse response = profileBlockingStub.getExpiringProfileKeyCredential(request); + final GetExpiringProfileKeyCredentialResponse response = authenticatedServiceStub().getExpiringProfileKeyCredential(request); assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray()); @@ -677,9 +637,8 @@ public class ProfileGrpcServiceTest { @Test void getExpiringProfileKeyCredentialRateLimited() { - final Duration retryAfterDuration = Duration.ofMinutes(5); - when(rateLimiter.validateReactive(AUTHENTICATED_ACI)) - .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false))); + final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail( + rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(5), false); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder() @@ -692,14 +651,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getExpiringProfileKeyCredential(request)); - - assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); - assertNotNull(exception.getTrailers()); - assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); - - verifyNoInteractions(profilesManager); + assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request), profilesManager); } @ParameterizedTest @@ -723,10 +675,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getExpiringProfileKeyCredential(request)); - - assertEquals(Status.Code.NOT_FOUND, statusRuntimeException.getStatus().getCode()); + assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request)); } private static Stream getExpiringProfileKeyCredentialAccountOrProfileNotFound() { @@ -761,10 +710,7 @@ public class ProfileGrpcServiceTest { .setVersion("someVersion") .build(); - StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, - () -> profileBlockingStub.getExpiringProfileKeyCredential(request)); - - assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode()); + assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request)); } private static Stream getExpiringProfileKeyCredentialInvalidArgument() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java new file mode 100644 index 000000000..8d3907022 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/SimpleBaseGrpcTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static java.util.Objects.requireNonNull; + +import io.grpc.BindableService; +import io.grpc.Channel; +import io.grpc.stub.AbstractBlockingStub; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.UUID; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.MockitoAnnotations; +import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.storage.Device; + +/** + * Base class for the common case of gRPC services tests. This base class makes some assumptions + * and introduces some constraints on the implementing classes with a goal of simplifying the process + * of creating a test for the most of the gRPC services. + *
    + *
  • + * Test classes extending this class will have to override the {@link #createServiceBeforeEachTest()} method + * with the logic that creates an instance of the service to test. This method is called before each test and should + * contain other setup code that would normally go into {@code @BeforeEach} method. + *
  • + *
  • + * This base class takes care of creating two service stubs: {@code authenticatedServiceStub} and {@code unauthenticatedServiceStub}. + * Normally, those stubs are created by the call to the {@code newBlockingStub()} method on the {@code *Stub} class, e.g.: + *
    CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
    + * In this class, those stubs are created by the {@link #createStub(Channel)} method that has a default implementation that is based on + * figuring out the name of the {@code `*Grpc`} class and invoking {@code `*Grpc.newBlockingStub()`} method with reflection. + *
  • + *
  • + * This class takes care of initializing {@code Mockito} annotations processing, so implementing classes + * can annotate their fields with {@code @Mock} and have those mocks ready by the time {@link #createServiceBeforeEachTest()} is called. + *
  • + *
+ * @param Class of the gRPC service that is being tested. + * @param Class of the gRPC service stub. + */ +public abstract class SimpleBaseGrpcTest> { + + @RegisterExtension + protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_AUTHENTICATED = new GrpcServerExtension(); + + @RegisterExtension + protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_UNAUTHENTICATED = new GrpcServerExtension(); + + protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); + + protected static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; + + private AutoCloseable mocksCloseable; + + private MockAuthenticationInterceptor mockAuthenticationInterceptor; + + private SERVICE service; + + private STUB authenticatedServiceStub; + + private STUB unauthenticatedServiceStub; + + + /** + * This method is invoked before each test and is expected to create an instance of the gRPC service + * that is being tested and also to perform all necessary before-each setup. + *

+ * Extending classes may have their own {@code @BeforeEach} method, but it will be called after this method. + * @return an instance of the gRPC service. + */ + protected abstract SERVICE createServiceBeforeEachTest(); + + /** + * The default implementation of this method is based on figuring out the name of the {@code `*Grpc`} class + * and invoking {@code `*Grpc.newBlockingStub()`} method with reflection. + *

+ * Overriding this method can be helpful if addutional configuration of the stub is required, e.g. adding interceptors: + *

+   * protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
+   *   final Metadata metadata = new Metadata();
+   *   metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
+   *   metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
+   *   return super.createStub(channel)
+   *       .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
+   * }
+   * 
+ * @param channel grpc channel to create create the stub for. + * @return and instance of the service stub. + */ + protected STUB createStub(final Channel channel) throws + ClassNotFoundException, + NoSuchMethodException, + InvocationTargetException, + IllegalAccessException { + final String serviceClassName = service.bindService().getServiceDescriptor().getName(); + final String grpcClassName = serviceClassName + "Grpc"; + final Class grpcClass = ClassLoader.getSystemClassLoader().loadClass(grpcClassName); + final Method newBlockingStubMethod = grpcClass.getMethod("newBlockingStub", Channel.class); + final Object stub = newBlockingStubMethod.invoke(null, channel); + //noinspection unchecked + return (STUB) stub; + } + + @BeforeEach + protected void baseSetup() { + mocksCloseable = MockitoAnnotations.openMocks(this); + service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`"); + mockAuthenticationInterceptor = GrpcTestUtils.setupAuthenticatedExtension( + GRPC_SERVER_EXTENSION_AUTHENTICATED, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service); + GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, service); + try { + authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel()); + unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel()); + } catch (Exception e) { + throw new RuntimeException("Could not create a stub based on the service name. Try overriding `createStub()` method."); + } + } + + @AfterEach + public void releaseMocks() throws Exception { + mocksCloseable.close(); + } + + public MockAuthenticationInterceptor mockAuthenticationInterceptor() { + return mockAuthenticationInterceptor; + } + + protected SERVICE service() { + return service; + } + + protected STUB authenticatedServiceStub() { + return authenticatedServiceStub; + } + + protected STUB unauthenticatedServiceStub() { + return unauthenticatedServiceStub; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java index 3173f8753..9d90e2405 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java @@ -111,7 +111,7 @@ public class RateLimitersTest { final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig(); - assertEquals(expected, limiter.config()); + assertEquals(expected, config(limiter)); } @Test @@ -131,16 +131,16 @@ public class RateLimitersTest { final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig); - assertEquals(initialRateLimiterConfig, limiter.config()); + assertEquals(initialRateLimiterConfig, config(limiter)); - assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config()); - assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config()); + assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter())); + assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter())); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig); - assertEquals(updatedRateLimiterCongig, limiter.config()); + assertEquals(updatedRateLimiterCongig, config(limiter)); - assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config()); - assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config()); + assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter())); + assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter())); } @Test @@ -161,22 +161,22 @@ public class RateLimitersTest { // test only default is present mapForDynamic.remove(descriptor.id()); mapForStatic.remove(descriptor.id()); - assertEquals(defaultConfig, limiter.config()); + assertEquals(defaultConfig, config(limiter)); // test dynamic and no static mapForDynamic.put(descriptor.id(), configForDynamic); mapForStatic.remove(descriptor.id()); - assertEquals(configForDynamic, limiter.config()); + assertEquals(configForDynamic, config(limiter)); // test dynamic and static mapForDynamic.put(descriptor.id(), configForDynamic); mapForStatic.put(descriptor.id(), configForStatic); - assertEquals(configForDynamic, limiter.config()); + assertEquals(configForDynamic, config(limiter)); // test static, but no dynamic mapForDynamic.remove(descriptor.id()); mapForStatic.put(descriptor.id(), configForStatic); - assertEquals(configForStatic, limiter.config()); + assertEquals(configForStatic, config(limiter)); } private record TestDescriptor(String id) implements RateLimiterDescriptor { @@ -191,4 +191,14 @@ public class RateLimitersTest { return new RateLimiterConfig(1, Duration.ofMinutes(1)); } } + + private static RateLimiterConfig config(final RateLimiter rateLimiter) { + if (rateLimiter instanceof StaticRateLimiter rm) { + return rm.config(); + } + if (rateLimiter instanceof DynamicRateLimiter rm) { + return rm.config(); + } + throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index 16466b8fd..af5a62b3b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -12,12 +12,14 @@ import static org.mockito.Mockito.doThrow; import java.time.Duration; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.apache.commons.lang3.RandomUtils; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import reactor.core.publisher.Mono; public final class MockUtils { @@ -46,32 +48,80 @@ public final class MockUtils { } public static void updateRateLimiterResponseToAllow( - final RateLimiters rateLimitersMock, - final RateLimiters.For handle, + final RateLimiter mockRateLimiter, final String input) { - final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); - doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); try { doNothing().when(mockRateLimiter).validate(eq(input)); + doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input)); + doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input)); + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } + + public static void updateRateLimiterResponseToAllow( + final RateLimiter mockRateLimiter, + final UUID input) { + try { + doNothing().when(mockRateLimiter).validate(eq(input)); + doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input)); + doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input)); } catch (final RateLimitExceededException e) { throw new RuntimeException(e); } } + public static void updateRateLimiterResponseToAllow( + final RateLimiters rateLimitersMock, + final RateLimiters.For handle, + final String input) { + final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); + doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); + updateRateLimiterResponseToAllow(mockRateLimiter, input); + } + public static void updateRateLimiterResponseToAllow( final RateLimiters rateLimitersMock, final RateLimiters.For handle, final UUID input) { final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); + updateRateLimiterResponseToAllow(mockRateLimiter, input); + } + + public static Duration updateRateLimiterResponseToFail( + final RateLimiter mockRateLimiter, + final String input, + final Duration retryAfter, + final boolean legacyStatusCode) { try { - doNothing().when(mockRateLimiter).validate(eq(input)); + final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode); + doThrow(exception).when(mockRateLimiter).validate(eq(input)); + doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input)); + doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input)); + return retryAfter; } catch (final RateLimitExceededException e) { throw new RuntimeException(e); } } - public static void updateRateLimiterResponseToFail( + public static Duration updateRateLimiterResponseToFail( + final RateLimiter mockRateLimiter, + final UUID input, + final Duration retryAfter, + final boolean legacyStatusCode) { + try { + final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode); + doThrow(exception).when(mockRateLimiter).validate(eq(input)); + doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input)); + doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input)); + return retryAfter; + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } + + public static Duration updateRateLimiterResponseToFail( final RateLimiters rateLimitersMock, final RateLimiters.For handle, final String input, @@ -79,14 +129,10 @@ public final class MockUtils { final boolean legacyStatusCode) { final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); - try { - doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input)); - } catch (final RateLimitExceededException e) { - throw new RuntimeException(e); - } + return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode); } - public static void updateRateLimiterResponseToFail( + public static Duration updateRateLimiterResponseToFail( final RateLimiters rateLimitersMock, final RateLimiters.For handle, final UUID input, @@ -94,11 +140,7 @@ public final class MockUtils { final boolean legacyStatusCode) { final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); - try { - doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input)); - } catch (final RateLimitExceededException e) { - throw new RuntimeException(e); - } + return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode); } public static SecretBytes randomSecretBytes(final int size) {