DRY gRPC tests, refactor error mapping

This commit is contained in:
Sergey Skrobotov 2023-09-08 16:08:59 -07:00
parent 29ca544c95
commit 977243ebfd
18 changed files with 637 additions and 525 deletions

View File

@ -120,6 +120,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
@ -644,8 +645,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager));
final ServerBuilder<?> grpcServer = ServerBuilder.forPort(config.getGrpcPort())
// TODO: specialize metrics with user-agent platform
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
@ -657,13 +656,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addFilter("RemoteDeprecationFilter", remoteDeprecationFilter)
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
grpcServer.intercept(new AcceptLanguageInterceptor());
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
grpcServer.intercept(remoteDeprecationFilter);
grpcServer.intercept(new UserAgentInterceptor());
grpcServer
// TODO: specialize metrics with user-agent platform
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
.intercept(new ErrorMappingInterceptor())
.intercept(new AcceptLanguageInterceptor())
.intercept(remoteDeprecationFilter)
.intercept(new UserAgentInterceptor());
environment.lifecycle().manage(new GrpcServerManagedWrapper(grpcServer.build()));

View File

@ -1,14 +1,30 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import io.grpc.Metadata;
import io.grpc.Status;
import java.time.Duration;
import java.util.Optional;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.grpc.ConvertibleToGrpcStatus;
public class RateLimitExceededException extends Exception {
public class RateLimitExceededException extends Exception implements ConvertibleToGrpcStatus {
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
@Override
public String toAsciiString(final Duration value) {
return value.toString();
}
@Override
public Duration parseAsciiString(final String serialized) {
return Duration.parse(serialized);
}
});
@Nullable
private final Duration retryDuration;
@ -33,4 +49,19 @@ public class RateLimitExceededException extends Exception {
public boolean isLegacy() {
return legacy;
}
@Override
public Status grpcStatus() {
return Status.RESOURCE_EXHAUSTED;
}
@Override
public Optional<Metadata> grpcMetadata() {
return getRetryDuration()
.map(duration -> {
final Metadata metadata = new Metadata();
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
return metadata;
});
}
}

View File

@ -24,11 +24,6 @@ public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase {
this.rateLimiters = rateLimiters;
}
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override
public Mono<GetTurnCredentialsResponse> getTurnCredentials(final GetTurnCredentialsRequest request) {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();

View File

@ -0,0 +1,20 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.Status;
import java.util.Optional;
/**
* Interface to be imlemented by our custom exceptions that are consistently mapped to a gRPC status.
*/
public interface ConvertibleToGrpcStatus {
Status grpcStatus();
Optional<Metadata> grpcMetadata();
}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
/**
* This interceptor observes responses from the service and if the response status is {@link Status#UNKNOWN}
* and there is a non-null cause which is an instance of {@link ConvertibleToGrpcStatus},
* then status and metadata to be returned to the client is resolved from that object.
* </p>
* This eliminates the need of having each service to override {@code `onErrorMap()`} method for commonly used exceptions.
*/
public class ErrorMappingInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) {
@Override
public void close(final Status status, final Metadata trailers) {
// The idea is to only apply the automatic conversion logic in the cases
// when there was no explicit decision by the service to provide a status.
// I.e. if at this point we see anything but the `UNKNOWN`,
// that means that some logic in the service made this decision already
// and automatic conversion may conflict with it.
if (status.getCode().equals(Status.Code.UNKNOWN)
&& status.getCause() instanceof ConvertibleToGrpcStatus convertibleToGrpcStatus) {
super.close(
convertibleToGrpcStatus.grpcStatus(),
convertibleToGrpcStatus.grpcMetadata().orElseGet(Metadata::new)
);
} else {
super.close(status, trailers);
}
}
}, headers);
}
}

View File

@ -74,11 +74,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
this.rateLimiters = rateLimiters;
}
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override
public Mono<GetPreKeyCountResponse> getPreKeyCount(final GetPreKeyCountRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)

View File

@ -100,11 +100,6 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
this.bucket = bucket;
}
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override
public Mono<SetProfileResponse> setProfile(final SetProfileRequest request) {
validateRequest(request);

View File

@ -1,45 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Duration;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public class RateLimitUtil {
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
@Override
public String toAsciiString(final Duration value) {
return value.toString();
}
@Override
public Duration parseAsciiString(final String serialized) {
return Duration.parse(serialized);
}
});
public static Throwable mapRateLimitExceededException(final Throwable throwable) {
if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
@Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
.map(duration -> {
final Metadata metadata = new Metadata();
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
return metadata;
}).orElse(null);
return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
}
return throwable;
}
}

View File

@ -6,66 +6,41 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mock;
import org.signal.chat.calling.CallingGrpc;
import org.signal.chat.calling.GetTurnCredentialsRequest;
import org.signal.chat.calling.GetTurnCredentialsResponse;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Mono;
import org.whispersystems.textsecuregcm.util.MockUtils;
class CallingGrpcServiceTest {
class CallingGrpcServiceTest extends SimpleBaseGrpcTest<CallingGrpcService, CallingGrpc.CallingBlockingStub> {
@Mock
private TurnTokenGenerator turnTokenGenerator;
@Mock
private RateLimiter turnCredentialRateLimiter;
private CallingGrpc.CallingBlockingStub callingStub;
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setUp() {
turnTokenGenerator = mock(TurnTokenGenerator.class);
turnCredentialRateLimiter = mock(RateLimiter.class);
@Override
protected CallingGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter);
final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters);
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor));
callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
return new CallingGrpcService(turnTokenGenerator, rateLimiters);
}
@Test
@ -74,10 +49,10 @@ class CallingGrpcServiceTest {
final String password = "test-password";
final List<String> urls = List.of("first", "second");
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty());
MockUtils.updateRateLimiterResponseToAllow(turnCredentialRateLimiter, AUTHENTICATED_ACI);
when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls));
final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
final GetTurnCredentialsResponse response = authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder()
.setUsername(username)
@ -90,20 +65,10 @@ class CallingGrpcServiceTest {
@Test
void getTurnCredentialsRateLimited() {
final Duration retryAfter = Duration.ofMinutes(19);
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
final Duration retryAfter = MockUtils.updateRateLimiterResponseToFail(
turnCredentialRateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(19), false);
assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
verify(turnTokenGenerator, never()).generate(any());
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(turnTokenGenerator);
}
}

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
@ -14,11 +13,10 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
@ -26,21 +24,19 @@ import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.device.ClearPushTokenRequest;
import org.signal.chat.device.ClearPushTokenResponse;
import org.signal.chat.device.DevicesGrpc;
@ -54,42 +50,31 @@ import org.signal.chat.device.SetDeviceNameRequest;
import org.signal.chat.device.SetDeviceNameResponse;
import org.signal.chat.device.SetPushTokenRequest;
import org.signal.chat.device.SetPushTokenResponse;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
class DevicesGrpcServiceTest {
class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, DevicesGrpc.DevicesBlockingStub> {
@Mock
private AccountsManager accountsManager;
@Mock
private KeysManager keysManager;
@Mock
private MessagesManager messagesManager;
@Mock
private Account authenticatedAccount;
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
private DevicesGrpc.DevicesBlockingStub devicesStub;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
@BeforeEach
void setUp() {
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
messagesManager = mock(MessagesManager.class);
authenticatedAccount = mock(Account.class);
@Override
protected DevicesGrpcService createServiceBeforeEachTest() {
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI))
.thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
@ -117,11 +102,7 @@ class DevicesGrpcServiceTest {
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
final DevicesGrpcService devicesGrpcService = new DevicesGrpcService(accountsManager, keysManager, messagesManager);
devicesStub = DevicesGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(devicesGrpcService, mockAuthenticationInterceptor));
return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
}
@Test
@ -161,14 +142,14 @@ class DevicesGrpcServiceTest {
.build())
.build();
assertEquals(expectedResponse, devicesStub.getDevices(GetDevicesRequest.newBuilder().build()));
assertEquals(expectedResponse, authenticatedServiceStub().getDevices(GetDevicesRequest.newBuilder().build()));
}
@Test
void removeDevice() {
final long deviceId = 17;
final RemoveDeviceResponse ignored = devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId)
.build());
@ -179,30 +160,23 @@ class DevicesGrpcServiceTest {
@Test
void removeDevicePrimary() {
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
.setId(1)
.build()));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(1)
.build()));
}
@Test
void removeDeviceNonPrimaryAuthenticated() {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
.setId(17)
.build()));
assertEquals(Status.Code.PERMISSION_DENIED, exception.getStatus().getCode());
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(17)
.build()));
}
@ParameterizedTest
@ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1})
void setDeviceName(final long deviceId) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
@ -210,7 +184,7 @@ class DevicesGrpcServiceTest {
final byte[] deviceName = new byte[128];
ThreadLocalRandom.current().nextBytes(deviceName);
final SetDeviceNameResponse ignored = devicesStub.setDeviceName(SetDeviceNameRequest.newBuilder()
final SetDeviceNameResponse ignored = authenticatedServiceStub().setDeviceName(SetDeviceNameRequest.newBuilder()
.setName(ByteString.copyFrom(deviceName))
.build());
@ -221,11 +195,7 @@ class DevicesGrpcServiceTest {
@MethodSource
void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> devicesStub.setDeviceName(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setDeviceName(request));
}
private static Stream<Arguments> setDeviceNameIllegalArgument() {
@ -248,12 +218,12 @@ class DevicesGrpcServiceTest {
@Nullable final String expectedApnsVoipToken,
@Nullable final String expectedFcmToken) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(device).setApnId(expectedApnsToken);
verify(device).setVoipApnId(expectedApnsVoipToken);
@ -312,7 +282,7 @@ class DevicesGrpcServiceTest {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
}
@ -352,12 +322,7 @@ class DevicesGrpcServiceTest {
void setPushTokenIllegalArgument(final SetPushTokenRequest request) {
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> devicesStub.setPushToken(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
}
@ -383,7 +348,7 @@ class DevicesGrpcServiceTest {
@Nullable final String fcmToken,
@Nullable final String expectedUserAgent) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
@ -393,7 +358,7 @@ class DevicesGrpcServiceTest {
when(device.getGcmId()).thenReturn(fcmToken);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final ClearPushTokenResponse ignored = devicesStub.clearPushToken(ClearPushTokenRequest.newBuilder().build());
final ClearPushTokenResponse ignored = authenticatedServiceStub().clearPushToken(ClearPushTokenRequest.newBuilder().build());
verify(device).setApnId(null);
verify(device).setVoipApnId(null);
@ -430,12 +395,12 @@ class DevicesGrpcServiceTest {
@CartesianTest.Values(booleans = {true, false}) final boolean pni,
@CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final SetCapabilitiesResponse ignored = devicesStub.setCapabilities(SetCapabilitiesRequest.newBuilder()
final SetCapabilitiesResponse ignored = authenticatedServiceStub().setCapabilities(SetCapabilitiesRequest.newBuilder()
.setStorage(storage)
.setTransfer(transfer)
.setPni(pni)

View File

@ -0,0 +1,65 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.verifyNoInteractions;
import io.grpc.BindableService;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.UUID;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.function.Executable;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public final class GrpcTestUtils {
private GrpcTestUtils() {
// noop
}
public static MockAuthenticationInterceptor setupAuthenticatedExtension(
final GrpcServerExtension extension,
final UUID authenticatedAci,
final long authenticatedDeviceId,
final BindableService service) {
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
return mockAuthenticationInterceptor;
}
public static void setupUnauthenticatedExtension(
final GrpcServerExtension extension,
final BindableService service) {
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, new ErrorMappingInterceptor()));
}
public static void assertStatusException(final Status expected, final Executable serviceCall) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(expected.getCode(), exception.getStatus().getCode());
}
public static void assertRateLimitExceeded(
final Duration expectedRetryAfter,
final Executable serviceCall,
final Object... mocksToCheckForNoInteraction) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus());
assertNotNull(exception.getTrailers());
assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY));
for (final Object mock: mocksToCheckForNoInteraction) {
verifyNoInteractions(mock);
}
}
}

View File

@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.Status;
@ -20,11 +21,10 @@ import java.util.Collections;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
@ -48,27 +48,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class KeysAnonymousGrpcServiceTest {
class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcService, KeysAnonymousGrpc.KeysAnonymousBlockingStub> {
@Mock
private AccountsManager accountsManager;
@Mock
private KeysManager keysManager;
private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setUp() {
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
final KeysAnonymousGrpcService keysGrpcService =
new KeysAnonymousGrpcService(accountsManager, keysManager);
keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
@Override
protected KeysAnonymousGrpcService createServiceBeforeEachTest() {
return new KeysAnonymousGrpcService(accountsManager, keysManager);
}
@Test
@ -101,7 +92,7 @@ class KeysAnonymousGrpcServiceTest {
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
@ -152,19 +143,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException statusRuntimeException =
assertThrows(StatusRuntimeException.class,
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.setDeviceId(Device.MASTER_ID)
.build())
.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.setDeviceId(Device.MASTER_ID)
.build())
.build()));
}
@Test
@ -174,7 +161,7 @@ class KeysAnonymousGrpcServiceTest {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class,
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
() -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
@ -205,19 +192,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class,
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(accountIdentifier))
.build())
.setDeviceId(deviceId)
.build())
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(accountIdentifier))
.build())
.setDeviceId(deviceId)
.build())
.build()));
}
}

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
@ -16,9 +15,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
@ -32,14 +32,13 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
@ -56,7 +55,6 @@ import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@ -73,41 +71,34 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono;
class KeysGrpcServiceTest {
private AccountsManager accountsManager;
private KeysManager keysManager;
private RateLimiter preKeysRateLimiter;
private Device authenticatedDevice;
private KeysGrpc.KeysBlockingStub keysStub;
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.KeysBlockingStub> {
private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
protected static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
@BeforeEach
void setUp() {
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
preKeysRateLimiter = mock(RateLimiter.class);
@Mock
private AccountsManager accountsManager;
@Mock
private KeysManager keysManager;
@Mock
private RateLimiter preKeysRateLimiter;
@Mock
private Device authenticatedDevice;
@Override
protected KeysGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
authenticatedDevice = mock(Device.class);
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
final Account authenticatedAccount = mock(Account.class);
@ -119,17 +110,13 @@ class KeysGrpcServiceTest {
when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
return new KeysGrpcService(accountsManager, keysManager, rateLimiters);
}
@Test
@ -152,7 +139,7 @@ class KeysGrpcServiceTest {
.setPniEcPreKeyCount(3)
.setPniKemPreKeyCount(4)
.build(),
keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
authenticatedServiceStub().getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
}
@ParameterizedTest
@ -168,7 +155,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
authenticatedServiceStub().setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
.map(preKey -> EcPreKey.newBuilder()
@ -189,10 +176,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeEcPreKeys(request));
}
private static Stream<Arguments> setOneTimeEcPreKeysWithError() {
@ -242,7 +226,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
keysStub.setOneTimeKemSignedPreKeys(
authenticatedServiceStub().setOneTimeKemSignedPreKeys(
SetOneTimeKemSignedPreKeysRequest.newBuilder()
.setIdentityType(identityType)
.addAllPreKeys(preKeys.stream()
@ -265,10 +249,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeKemSignedPreKeys(request));
}
private static Stream<Arguments> setOneTimeKemSignedPreKeysWithError() {
@ -333,7 +314,7 @@ class KeysGrpcServiceTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored
keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(signedPreKey.keyId())
@ -359,7 +340,7 @@ class KeysGrpcServiceTest {
@MethodSource
void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request));
assertThrows(StatusRuntimeException.class, () -> authenticatedServiceStub().setEcSignedPreKey(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
}
@ -416,7 +397,7 @@ class KeysGrpcServiceTest {
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored
keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
.setIdentityType(identityType)
.setSignedPreKey(KemSignedPreKey.newBuilder()
.setKeyId(lastResortPreKey.keyId())
@ -437,10 +418,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setKemLastResortPreKey(request));
}
private static Stream<Arguments> setLastResortPreKeyWithError() {
@ -528,7 +506,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
{
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
@ -563,7 +541,7 @@ class KeysGrpcServiceTest {
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
{
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
@ -606,15 +584,12 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build()));
}
@ParameterizedTest
@ -631,16 +606,13 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(accountIdentifier))
.build())
.setDeviceId(deviceId)
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(accountIdentifier))
.build())
.setDeviceId(deviceId)
.build()));
}
@Test
@ -655,22 +627,15 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Duration retryAfterDuration = Duration.ofMinutes(7);
when(preKeysRateLimiter.validateReactive(anyString()))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build()));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.build())
.build()));
verifyNoInteractions(accountsManager);
}
}

View File

@ -1,19 +1,27 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.stub.MetadataUtils;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Instant;
@ -24,14 +32,12 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.MetadataUtils;
import org.junit.jupiter.api.BeforeEach;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType;
@ -72,43 +78,41 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import javax.annotation.Nullable;
public class ProfileAnonymousGrpcServiceTest {
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@Mock
private Account account;
@Mock
private AccountsManager accountsManager;
@Mock
private ProfilesManager profilesManager;
@Mock
private ProfileBadgeConverter profileBadgeConverter;
private ProfileAnonymousGrpc.ProfileAnonymousBlockingStub profileAnonymousBlockingStub;
@Mock
private ServerZkProfileOperations serverZkProfileOperations;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setup() {
account = mock(Account.class);
accountsManager = mock(AccountsManager.class);
profilesManager = mock(ProfilesManager.class);
profileBadgeConverter = mock(ProfileBadgeConverter.class);
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
profileAnonymousBlockingStub = ProfileAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final ProfileAnonymousGrpcService profileAnonymousGrpcService = new ProfileAnonymousGrpcService(
@Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
return new ProfileAnonymousGrpcService(
accountsManager,
profilesManager,
profileBadgeConverter,
serverZkProfileOperations
);
}
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(profileAnonymousGrpcService);
@Override
protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}
@Test
@ -151,7 +155,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final GetUnversionedProfileResponse response = profileAnonymousBlockingStub.getUnversionedProfile(request);
final GetUnversionedProfileResponse response = unauthenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder()
@ -189,10 +193,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getUnversionedProfile(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getUnversionedProfile(requestBuilder.build()));
}
private static Stream<Arguments> getUnversionedProfileUnauthenticated() {
@ -242,7 +243,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final GetVersionedProfileResponse response = profileAnonymousBlockingStub.getVersionedProfile(request);
final GetVersionedProfileResponse response = unauthenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name))
@ -287,10 +288,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getVersionedProfile(request));
}
@ParameterizedTest
@ -318,18 +316,15 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getVersionedProfile(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getVersionedProfile(requestBuilder.build()));
}
private static Stream<Arguments> getVersionedProfileUnauthenticated() {
return Stream.of(
Arguments.of(true, false),
Arguments.of(false, true)
);
}
@Test
void getVersionedProfilePniInvalidArgument() {
final byte[] unidentifiedAccessKey = new byte[16];
@ -346,10 +341,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getVersionedProfile(request));
}
@Test
@ -404,7 +396,7 @@ public class ProfileAnonymousGrpcServiceTest {
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.build();
final GetExpiringProfileKeyCredentialResponse response = profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request);
final GetExpiringProfileKeyCredentialResponse response = unauthenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@ -442,10 +434,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
}
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(requestBuilder.build()));
verifyNoInteractions(profilesManager);
}
@ -483,10 +472,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
@ParameterizedTest
@ -521,10 +507,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build())
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {

View File

@ -10,8 +10,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
@ -21,15 +19,14 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
import io.grpc.Metadata;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.MetadataUtils;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Clock;
@ -44,15 +41,14 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType;
@ -82,7 +78,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
@ -100,49 +95,54 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
public class ProfileGrpcServiceTest {
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcService, ProfileGrpc.ProfileBlockingStub> {
private static final String S3_BUCKET = "profileBucket";
private static final String VERSION = "someVersion";
private static final byte[] VALID_NAME = new byte[81];
@Mock
private AccountsManager accountsManager;
@Mock
private ProfilesManager profilesManager;
@Mock
private DynamicPaymentsConfiguration dynamicPaymentsConfiguration;
@Mock
private S3AsyncClient asyncS3client;
@Mock
private VersionedProfile profile;
@Mock
private Account account;
@Mock
private RateLimiter rateLimiter;
@Mock
private ProfileBadgeConverter profileBadgeConverter;
@Mock
private ServerZkProfileOperations serverZkProfileOperations;
private ProfileGrpc.ProfileBlockingStub profileBlockingStub;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
profilesManager = mock(ProfilesManager.class);
dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
asyncS3client = mock(S3AsyncClient.class);
profile = mock(VersionedProfile.class);
account = mock(Account.class);
rateLimiter = mock(RateLimiter.class);
profileBadgeConverter = mock(ProfileBadgeConverter.class);
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
@Override
protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1");
@ -170,30 +170,6 @@ public class ProfileGrpcServiceTest {
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
profileBlockingStub = ProfileGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final ProfileGrpcService profileGrpcService = new ProfileGrpcService(
Clock.systemUTC(),
accountsManager,
profilesManager,
dynamicConfigurationManager,
badgesConfiguration,
asyncS3client,
policyGenerator,
policySigner,
profileBadgeConverter,
rateLimiters,
serverZkProfileOperations,
S3_BUCKET
);
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(profileGrpcService, mockAuthenticationInterceptor));
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -218,6 +194,21 @@ public class ProfileGrpcServiceTest {
when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList());
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
return new ProfileGrpcService(
Clock.systemUTC(),
accountsManager,
profilesManager,
dynamicConfigurationManager,
badgesConfiguration,
asyncS3client,
policyGenerator,
policySigner,
profileBadgeConverter,
rateLimiters,
serverZkProfileOperations,
S3_BUCKET
);
}
@Test
@ -237,7 +228,7 @@ public class ProfileGrpcServiceTest {
.setCommitment(ByteString.copyFrom(commitment))
.build();
profileBlockingStub.setProfile(request);
authenticatedServiceStub().setProfile(request);
final ArgumentCaptor<VersionedProfile> profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class);
@ -274,7 +265,7 @@ public class ProfileGrpcServiceTest {
hasPreviousProfile ? Optional.of(profile) : Optional.empty()));
when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
SetProfileResponse response = profileBlockingStub.setProfile(request);
final SetProfileResponse response = authenticatedServiceStub().setProfile(request);
if (expectHasS3UploadPath) {
assertTrue(response.getAttributes().getPath().startsWith("profiles/"));
@ -312,10 +303,7 @@ public class ProfileGrpcServiceTest {
@ParameterizedTest
@MethodSource
void setProfileInvalidRequestData(final SetProfileRequest request) {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setProfile(request));
}
private static Stream<Arguments> setProfileInvalidRequestData() throws InvalidInputException{
@ -386,12 +374,10 @@ public class ProfileGrpcServiceTest {
when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
if (hasExistingPaymentAddress) {
assertDoesNotThrow(() -> profileBlockingStub.setProfile(request),
assertDoesNotThrow(() -> authenticatedServiceStub().setProfile(request),
"Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address");
} else {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
assertEquals(Status.PERMISSION_DENIED.getCode(), exception.getStatus().getCode());
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().setProfile(request));
}
}
@ -433,7 +419,7 @@ public class ProfileGrpcServiceTest {
when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges);
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
final GetUnversionedProfileResponse response = profileBlockingStub.getUnversionedProfile(request);
final GetUnversionedProfileResponse response = authenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder()
@ -472,10 +458,7 @@ public class ProfileGrpcServiceTest {
.build())
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getUnversionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getUnversionedProfile(request));
}
@ParameterizedTest
@ -493,14 +476,7 @@ public class ProfileGrpcServiceTest {
.build())
.build();
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.getUnversionedProfile(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(accountsManager);
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getUnversionedProfile(request), accountsManager);
}
@ParameterizedTest
@ -531,7 +507,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
final GetVersionedProfileResponse response = profileBlockingStub.getVersionedProfile(request);
final GetVersionedProfileResponse response = authenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name))
@ -545,7 +521,6 @@ public class ProfileGrpcServiceTest {
assertEquals(expectedResponseBuilder.build(), response);
}
private static Stream<Arguments> getVersionedProfile() {
return Stream.of(
Arguments.of("version1", "version1", true),
@ -553,6 +528,7 @@ public class ProfileGrpcServiceTest {
Arguments.of("version1", "version2", false)
);
}
@ParameterizedTest
@MethodSource
void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) {
@ -566,10 +542,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile)));
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getVersionedProfile(request));
}
private static Stream<Arguments> getVersionedProfileAccountOrProfileNotFound() {
@ -581,10 +554,7 @@ public class ProfileGrpcServiceTest {
@Test
void getVersionedProfileRatelimited() {
final Duration retryAfterDuration = Duration.ofMinutes(7);
when(rateLimiter.validateReactive(any(UUID.class)))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(7), false);
final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder()
.setAccountIdentifier(ServiceIdentifier.newBuilder()
@ -594,15 +564,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(accountsManager);
verifyNoInteractions(profilesManager);
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getVersionedProfile(request), accountsManager, profilesManager);
}
@Test
@ -615,9 +577,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getVersionedProfile(request));
}
@Test
@ -664,7 +624,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
final GetExpiringProfileKeyCredentialResponse response = profileBlockingStub.getExpiringProfileKeyCredential(request);
final GetExpiringProfileKeyCredentialResponse response = authenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@ -677,9 +637,8 @@ public class ProfileGrpcServiceTest {
@Test
void getExpiringProfileKeyCredentialRateLimited() {
final Duration retryAfterDuration = Duration.ofMinutes(5);
when(rateLimiter.validateReactive(AUTHENTICATED_ACI))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(
rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(5), false);
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder()
@ -692,14 +651,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(profilesManager);
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request), profilesManager);
}
@ParameterizedTest
@ -723,10 +675,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.NOT_FOUND, statusRuntimeException.getStatus().getCode());
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream<Arguments> getExpiringProfileKeyCredentialAccountOrProfileNotFound() {
@ -761,10 +710,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion")
.build();
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
}
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {

View File

@ -0,0 +1,146 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static java.util.Objects.requireNonNull;
import io.grpc.BindableService;
import io.grpc.Channel;
import io.grpc.stub.AbstractBlockingStub;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.UUID;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.MockitoAnnotations;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.storage.Device;
/**
* Base class for the common case of gRPC services tests. This base class makes some assumptions
* and introduces some constraints on the implementing classes with a goal of simplifying the process
* of creating a test for the most of the gRPC services.
* <ul>
* <li>
* Test classes extending this class will have to override the {@link #createServiceBeforeEachTest()} method
* with the logic that creates an instance of the service to test. This method is called before each test and should
* contain other setup code that would normally go into {@code @BeforeEach} method.
* </li>
* <li>
* This base class takes care of creating two service stubs: {@code authenticatedServiceStub} and {@code unauthenticatedServiceStub}.
* Normally, those stubs are created by the call to the {@code newBlockingStub()} method on the {@code *Stub} class, e.g.:
* <pre>CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());</pre>
* In this class, those stubs are created by the {@link #createStub(Channel)} method that has a default implementation that is based on
* figuring out the name of the {@code `*Grpc`} class and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
* </li>
* <li>
* This class takes care of initializing {@code Mockito} annotations processing, so implementing classes
* can annotate their fields with {@code @Mock} and have those mocks ready by the time {@link #createServiceBeforeEachTest()} is called.
* </li>
* </ul>
* @param <SERVICE> Class of the gRPC service that is being tested.
* @param <STUB> Class of the gRPC service stub.
*/
public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB extends AbstractBlockingStub<STUB>> {
@RegisterExtension
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_AUTHENTICATED = new GrpcServerExtension();
@RegisterExtension
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_UNAUTHENTICATED = new GrpcServerExtension();
protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
protected static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
private AutoCloseable mocksCloseable;
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
private SERVICE service;
private STUB authenticatedServiceStub;
private STUB unauthenticatedServiceStub;
/**
* This method is invoked before each test and is expected to create an instance of the gRPC service
* that is being tested and also to perform all necessary before-each setup.
* </p>
* Extending classes may have their own {@code @BeforeEach} method, but it will be called after this method.
* @return an instance of the gRPC service.
*/
protected abstract SERVICE createServiceBeforeEachTest();
/**
* The default implementation of this method is based on figuring out the name of the {@code `*Grpc`} class
* and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
* <p>
* Overriding this method can be helpful if addutional configuration of the stub is required, e.g. adding interceptors:
* <pre>
* protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
* final Metadata metadata = new Metadata();
* metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
* metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
* return super.createStub(channel)
* .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
* }
* </pre>
* @param channel grpc channel to create create the stub for.
* @return and instance of the service stub.
*/
protected STUB createStub(final Channel channel) throws
ClassNotFoundException,
NoSuchMethodException,
InvocationTargetException,
IllegalAccessException {
final String serviceClassName = service.bindService().getServiceDescriptor().getName();
final String grpcClassName = serviceClassName + "Grpc";
final Class<?> grpcClass = ClassLoader.getSystemClassLoader().loadClass(grpcClassName);
final Method newBlockingStubMethod = grpcClass.getMethod("newBlockingStub", Channel.class);
final Object stub = newBlockingStubMethod.invoke(null, channel);
//noinspection unchecked
return (STUB) stub;
}
@BeforeEach
protected void baseSetup() {
mocksCloseable = MockitoAnnotations.openMocks(this);
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
mockAuthenticationInterceptor = GrpcTestUtils.setupAuthenticatedExtension(
GRPC_SERVER_EXTENSION_AUTHENTICATED, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, service);
try {
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
} catch (Exception e) {
throw new RuntimeException("Could not create a stub based on the service name. Try overriding `createStub()` method.");
}
}
@AfterEach
public void releaseMocks() throws Exception {
mocksCloseable.close();
}
public MockAuthenticationInterceptor mockAuthenticationInterceptor() {
return mockAuthenticationInterceptor;
}
protected SERVICE service() {
return service;
}
protected STUB authenticatedServiceStub() {
return authenticatedServiceStub;
}
protected STUB unauthenticatedServiceStub() {
return unauthenticatedServiceStub;
}
}

View File

@ -111,7 +111,7 @@ public class RateLimitersTest {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
assertEquals(expected, config(limiter));
}
@Test
@ -131,16 +131,16 @@ public class RateLimitersTest {
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
assertEquals(initialRateLimiterConfig, limiter.config());
assertEquals(initialRateLimiterConfig, config(limiter));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
assertEquals(updatedRateLimiterCongig, limiter.config());
assertEquals(updatedRateLimiterCongig, config(limiter));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
}
@Test
@ -161,22 +161,22 @@ public class RateLimitersTest {
// test only default is present
mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id());
assertEquals(defaultConfig, limiter.config());
assertEquals(defaultConfig, config(limiter));
// test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id());
assertEquals(configForDynamic, limiter.config());
assertEquals(configForDynamic, config(limiter));
// test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForDynamic, limiter.config());
assertEquals(configForDynamic, config(limiter));
// test static, but no dynamic
mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForStatic, limiter.config());
assertEquals(configForStatic, config(limiter));
}
private record TestDescriptor(String id) implements RateLimiterDescriptor {
@ -191,4 +191,14 @@ public class RateLimitersTest {
return new RateLimiterConfig(1, Duration.ofMinutes(1));
}
}
private static RateLimiterConfig config(final RateLimiter rateLimiter) {
if (rateLimiter instanceof StaticRateLimiter rm) {
return rm.config();
}
if (rateLimiter instanceof DynamicRateLimiter rm) {
return rm.config();
}
throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName());
}
}

View File

@ -12,12 +12,14 @@ import static org.mockito.Mockito.doThrow;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomUtils;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import reactor.core.publisher.Mono;
public final class MockUtils {
@ -46,32 +48,80 @@ public final class MockUtils {
}
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final RateLimiter mockRateLimiter,
final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doNothing().when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static void updateRateLimiterResponseToAllow(
final RateLimiter mockRateLimiter,
final UUID input) {
try {
doNothing().when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
updateRateLimiterResponseToAllow(mockRateLimiter, input);
}
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
updateRateLimiterResponseToAllow(mockRateLimiter, input);
}
public static Duration updateRateLimiterResponseToFail(
final RateLimiter mockRateLimiter,
final String input,
final Duration retryAfter,
final boolean legacyStatusCode) {
try {
doNothing().when(mockRateLimiter).validate(eq(input));
final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
doThrow(exception).when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
return retryAfter;
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static void updateRateLimiterResponseToFail(
public static Duration updateRateLimiterResponseToFail(
final RateLimiter mockRateLimiter,
final UUID input,
final Duration retryAfter,
final boolean legacyStatusCode) {
try {
final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
doThrow(exception).when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
return retryAfter;
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final String input,
@ -79,14 +129,10 @@ public final class MockUtils {
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
}
public static void updateRateLimiterResponseToFail(
public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input,
@ -94,11 +140,7 @@ public final class MockUtils {
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
}
public static SecretBytes randomSecretBytes(final int size) {