DRY gRPC tests, refactor error mapping

This commit is contained in:
Sergey Skrobotov 2023-09-08 16:08:59 -07:00
parent 29ca544c95
commit 977243ebfd
18 changed files with 637 additions and 525 deletions

View File

@ -120,6 +120,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor; import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper; import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
@ -644,8 +645,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager)); new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager));
final ServerBuilder<?> grpcServer = ServerBuilder.forPort(config.getGrpcPort()) final ServerBuilder<?> grpcServer = ServerBuilder.forPort(config.getGrpcPort())
// TODO: specialize metrics with user-agent platform
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor)) .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys)) .addService(new KeysAnonymousGrpcService(accountsManager, keys))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, .addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
@ -657,13 +656,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addFilter("RemoteDeprecationFilter", remoteDeprecationFilter) .addFilter("RemoteDeprecationFilter", remoteDeprecationFilter)
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
grpcServer.intercept(new AcceptLanguageInterceptor());
// Note: interceptors run in the reverse order they are added; the remote deprecation filter // Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here! // depends on the user-agent context so it has to come first here!
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor- // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
grpcServer.intercept(remoteDeprecationFilter); grpcServer
grpcServer.intercept(new UserAgentInterceptor()); // TODO: specialize metrics with user-agent platform
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
.intercept(new ErrorMappingInterceptor())
.intercept(new AcceptLanguageInterceptor())
.intercept(remoteDeprecationFilter)
.intercept(new UserAgentInterceptor());
environment.lifecycle().manage(new GrpcServerManagedWrapper(grpcServer.build())); environment.lifecycle().manage(new GrpcServerManagedWrapper(grpcServer.build()));

View File

@ -1,14 +1,30 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import io.grpc.Metadata;
import io.grpc.Status;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.grpc.ConvertibleToGrpcStatus;
public class RateLimitExceededException extends Exception { public class RateLimitExceededException extends Exception implements ConvertibleToGrpcStatus {
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
@Override
public String toAsciiString(final Duration value) {
return value.toString();
}
@Override
public Duration parseAsciiString(final String serialized) {
return Duration.parse(serialized);
}
});
@Nullable @Nullable
private final Duration retryDuration; private final Duration retryDuration;
@ -33,4 +49,19 @@ public class RateLimitExceededException extends Exception {
public boolean isLegacy() { public boolean isLegacy() {
return legacy; return legacy;
} }
@Override
public Status grpcStatus() {
return Status.RESOURCE_EXHAUSTED;
}
@Override
public Optional<Metadata> grpcMetadata() {
return getRetryDuration()
.map(duration -> {
final Metadata metadata = new Metadata();
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
return metadata;
});
}
} }

View File

@ -24,11 +24,6 @@ public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
} }
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override @Override
public Mono<GetTurnCredentialsResponse> getTurnCredentials(final GetTurnCredentialsRequest request) { public Mono<GetTurnCredentialsResponse> getTurnCredentials(final GetTurnCredentialsRequest request) {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();

View File

@ -0,0 +1,20 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.Status;
import java.util.Optional;
/**
* Interface to be imlemented by our custom exceptions that are consistently mapped to a gRPC status.
*/
public interface ConvertibleToGrpcStatus {
Status grpcStatus();
Optional<Metadata> grpcMetadata();
}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
/**
* This interceptor observes responses from the service and if the response status is {@link Status#UNKNOWN}
* and there is a non-null cause which is an instance of {@link ConvertibleToGrpcStatus},
* then status and metadata to be returned to the client is resolved from that object.
* </p>
* This eliminates the need of having each service to override {@code `onErrorMap()`} method for commonly used exceptions.
*/
public class ErrorMappingInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) {
@Override
public void close(final Status status, final Metadata trailers) {
// The idea is to only apply the automatic conversion logic in the cases
// when there was no explicit decision by the service to provide a status.
// I.e. if at this point we see anything but the `UNKNOWN`,
// that means that some logic in the service made this decision already
// and automatic conversion may conflict with it.
if (status.getCode().equals(Status.Code.UNKNOWN)
&& status.getCause() instanceof ConvertibleToGrpcStatus convertibleToGrpcStatus) {
super.close(
convertibleToGrpcStatus.grpcStatus(),
convertibleToGrpcStatus.grpcMetadata().orElseGet(Metadata::new)
);
} else {
super.close(status, trailers);
}
}
}, headers);
}
}

View File

@ -74,11 +74,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
} }
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override @Override
public Mono<GetPreKeyCountResponse> getPreKeyCount(final GetPreKeyCountRequest request) { public Mono<GetPreKeyCountResponse> getPreKeyCount(final GetPreKeyCountRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)

View File

@ -100,11 +100,6 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
this.bucket = bucket; this.bucket = bucket;
} }
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override @Override
public Mono<SetProfileResponse> setProfile(final SetProfileRequest request) { public Mono<SetProfileResponse> setProfile(final SetProfileRequest request) {
validateRequest(request); validateRequest(request);

View File

@ -1,45 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Duration;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public class RateLimitUtil {
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
@Override
public String toAsciiString(final Duration value) {
return value.toString();
}
@Override
public Duration parseAsciiString(final String serialized) {
return Duration.parse(serialized);
}
});
public static Throwable mapRateLimitExceededException(final Throwable throwable) {
if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
@Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
.map(duration -> {
final Metadata metadata = new Metadata();
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
return metadata;
}).orElse(null);
return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
}
return throwable;
}
}

View File

@ -6,66 +6,41 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mock;
import org.signal.chat.calling.CallingGrpc; import org.signal.chat.calling.CallingGrpc;
import org.signal.chat.calling.GetTurnCredentialsRequest; import org.signal.chat.calling.GetTurnCredentialsRequest;
import org.signal.chat.calling.GetTurnCredentialsResponse; import org.signal.chat.calling.GetTurnCredentialsResponse;
import org.whispersystems.textsecuregcm.auth.TurnToken; import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.MockUtils;
import reactor.core.publisher.Mono;
class CallingGrpcServiceTest { class CallingGrpcServiceTest extends SimpleBaseGrpcTest<CallingGrpcService, CallingGrpc.CallingBlockingStub> {
@Mock
private TurnTokenGenerator turnTokenGenerator; private TurnTokenGenerator turnTokenGenerator;
@Mock
private RateLimiter turnCredentialRateLimiter; private RateLimiter turnCredentialRateLimiter;
private CallingGrpc.CallingBlockingStub callingStub;
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setUp() {
turnTokenGenerator = mock(TurnTokenGenerator.class);
turnCredentialRateLimiter = mock(RateLimiter.class);
@Override
protected CallingGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class); final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter); when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter);
return new CallingGrpcService(turnTokenGenerator, rateLimiters);
final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters);
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor));
callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
} }
@Test @Test
@ -74,10 +49,10 @@ class CallingGrpcServiceTest {
final String password = "test-password"; final String password = "test-password";
final List<String> urls = List.of("first", "second"); final List<String> urls = List.of("first", "second");
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty()); MockUtils.updateRateLimiterResponseToAllow(turnCredentialRateLimiter, AUTHENTICATED_ACI);
when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls)); when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls));
final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()); final GetTurnCredentialsResponse response = authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder() final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder()
.setUsername(username) .setUsername(username)
@ -90,20 +65,10 @@ class CallingGrpcServiceTest {
@Test @Test
void getTurnCredentialsRateLimited() { void getTurnCredentialsRateLimited() {
final Duration retryAfter = Duration.ofMinutes(19); final Duration retryAfter = MockUtils.updateRateLimiterResponseToFail(
turnCredentialRateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(19), false);
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)) assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
.thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false)));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
verify(turnTokenGenerator, never()).generate(any()); verify(turnTokenGenerator, never()).generate(any());
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(turnTokenGenerator); verifyNoInteractions(turnTokenGenerator);
} }
} }

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -14,11 +13,10 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.ServerInterceptors;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
@ -26,21 +24,19 @@ import java.time.temporal.ChronoUnit;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.Mock;
import org.signal.chat.device.ClearPushTokenRequest; import org.signal.chat.device.ClearPushTokenRequest;
import org.signal.chat.device.ClearPushTokenResponse; import org.signal.chat.device.ClearPushTokenResponse;
import org.signal.chat.device.DevicesGrpc; import org.signal.chat.device.DevicesGrpc;
@ -54,42 +50,31 @@ import org.signal.chat.device.SetDeviceNameRequest;
import org.signal.chat.device.SetDeviceNameResponse; import org.signal.chat.device.SetDeviceNameResponse;
import org.signal.chat.device.SetPushTokenRequest; import org.signal.chat.device.SetPushTokenRequest;
import org.signal.chat.device.SetPushTokenResponse; import org.signal.chat.device.SetPushTokenResponse;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
class DevicesGrpcServiceTest { class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, DevicesGrpc.DevicesBlockingStub> {
@Mock
private AccountsManager accountsManager; private AccountsManager accountsManager;
@Mock
private KeysManager keysManager; private KeysManager keysManager;
@Mock
private MessagesManager messagesManager; private MessagesManager messagesManager;
@Mock
private Account authenticatedAccount; private Account authenticatedAccount;
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
private DevicesGrpc.DevicesBlockingStub devicesStub;
@RegisterExtension @Override
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); protected DevicesGrpcService createServiceBeforeEachTest() {
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
@BeforeEach
void setUp() {
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
messagesManager = mock(MessagesManager.class);
authenticatedAccount = mock(Account.class);
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI); when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)) when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI))
.thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); .thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
@ -117,11 +102,7 @@ class DevicesGrpcServiceTest {
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
final DevicesGrpcService devicesGrpcService = new DevicesGrpcService(accountsManager, keysManager, messagesManager); return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
devicesStub = DevicesGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(devicesGrpcService, mockAuthenticationInterceptor));
} }
@Test @Test
@ -161,14 +142,14 @@ class DevicesGrpcServiceTest {
.build()) .build())
.build(); .build();
assertEquals(expectedResponse, devicesStub.getDevices(GetDevicesRequest.newBuilder().build())); assertEquals(expectedResponse, authenticatedServiceStub().getDevices(GetDevicesRequest.newBuilder().build()));
} }
@Test @Test
void removeDevice() { void removeDevice() {
final long deviceId = 17; final long deviceId = 17;
final RemoveDeviceResponse ignored = devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId) .setId(deviceId)
.build()); .build());
@ -179,30 +160,23 @@ class DevicesGrpcServiceTest {
@Test @Test
void removeDevicePrimary() { void removeDevicePrimary() {
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() .setId(1)
.setId(1) .build()));
.build()));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
} }
@Test @Test
void removeDeviceNonPrimaryAuthenticated() { void removeDeviceNonPrimaryAuthenticated() {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1); mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, .setId(17)
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder() .build()));
.setId(17)
.build()));
assertEquals(Status.Code.PERMISSION_DENIED, exception.getStatus().getCode());
} }
@ParameterizedTest @ParameterizedTest
@ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1}) @ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1})
void setDeviceName(final long deviceId) { void setDeviceName(final long deviceId) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
@ -210,7 +184,7 @@ class DevicesGrpcServiceTest {
final byte[] deviceName = new byte[128]; final byte[] deviceName = new byte[128];
ThreadLocalRandom.current().nextBytes(deviceName); ThreadLocalRandom.current().nextBytes(deviceName);
final SetDeviceNameResponse ignored = devicesStub.setDeviceName(SetDeviceNameRequest.newBuilder() final SetDeviceNameResponse ignored = authenticatedServiceStub().setDeviceName(SetDeviceNameRequest.newBuilder()
.setName(ByteString.copyFrom(deviceName)) .setName(ByteString.copyFrom(deviceName))
.build()); .build());
@ -221,11 +195,7 @@ class DevicesGrpcServiceTest {
@MethodSource @MethodSource
void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) { void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class))); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class)));
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setDeviceName(request));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> devicesStub.setDeviceName(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
} }
private static Stream<Arguments> setDeviceNameIllegalArgument() { private static Stream<Arguments> setDeviceNameIllegalArgument() {
@ -248,12 +218,12 @@ class DevicesGrpcServiceTest {
@Nullable final String expectedApnsVoipToken, @Nullable final String expectedApnsVoipToken,
@Nullable final String expectedFcmToken) { @Nullable final String expectedFcmToken) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final SetPushTokenResponse ignored = devicesStub.setPushToken(request); final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(device).setApnId(expectedApnsToken); verify(device).setApnId(expectedApnsToken);
verify(device).setVoipApnId(expectedApnsVoipToken); verify(device).setVoipApnId(expectedApnsVoipToken);
@ -312,7 +282,7 @@ class DevicesGrpcServiceTest {
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
final SetPushTokenResponse ignored = devicesStub.setPushToken(request); final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(accountsManager, never()).updateDevice(any(), anyLong(), any()); verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
} }
@ -352,12 +322,7 @@ class DevicesGrpcServiceTest {
void setPushTokenIllegalArgument(final SetPushTokenRequest request) { void setPushTokenIllegalArgument(final SetPushTokenRequest request) {
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> devicesStub.setPushToken(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
verify(accountsManager, never()).updateDevice(any(), anyLong(), any()); verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
} }
@ -383,7 +348,7 @@ class DevicesGrpcServiceTest {
@Nullable final String fcmToken, @Nullable final String fcmToken,
@Nullable final String expectedUserAgent) { @Nullable final String expectedUserAgent) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
@ -393,7 +358,7 @@ class DevicesGrpcServiceTest {
when(device.getGcmId()).thenReturn(fcmToken); when(device.getGcmId()).thenReturn(fcmToken);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final ClearPushTokenResponse ignored = devicesStub.clearPushToken(ClearPushTokenRequest.newBuilder().build()); final ClearPushTokenResponse ignored = authenticatedServiceStub().clearPushToken(ClearPushTokenRequest.newBuilder().build());
verify(device).setApnId(null); verify(device).setApnId(null);
verify(device).setVoipApnId(null); verify(device).setVoipApnId(null);
@ -430,12 +395,12 @@ class DevicesGrpcServiceTest {
@CartesianTest.Values(booleans = {true, false}) final boolean pni, @CartesianTest.Values(booleans = {true, false}) final boolean pni,
@CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) { @CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) {
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
final SetCapabilitiesResponse ignored = devicesStub.setCapabilities(SetCapabilitiesRequest.newBuilder() final SetCapabilitiesResponse ignored = authenticatedServiceStub().setCapabilities(SetCapabilitiesRequest.newBuilder()
.setStorage(storage) .setStorage(storage)
.setTransfer(transfer) .setTransfer(transfer)
.setPni(pni) .setPni(pni)

View File

@ -0,0 +1,65 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.verifyNoInteractions;
import io.grpc.BindableService;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.UUID;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.function.Executable;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public final class GrpcTestUtils {
private GrpcTestUtils() {
// noop
}
public static MockAuthenticationInterceptor setupAuthenticatedExtension(
final GrpcServerExtension extension,
final UUID authenticatedAci,
final long authenticatedDeviceId,
final BindableService service) {
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
return mockAuthenticationInterceptor;
}
public static void setupUnauthenticatedExtension(
final GrpcServerExtension extension,
final BindableService service) {
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, new ErrorMappingInterceptor()));
}
public static void assertStatusException(final Status expected, final Executable serviceCall) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(expected.getCode(), exception.getStatus().getCode());
}
public static void assertRateLimitExceeded(
final Duration expectedRetryAfter,
final Executable serviceCall,
final Object... mocksToCheckForNoInteraction) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus());
assertNotNull(exception.getTrailers());
assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY));
for (final Object mock: mocksToCheckForNoInteraction) {
verifyNoInteractions(mock);
}
}
}

View File

@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Status; import io.grpc.Status;
@ -20,11 +21,10 @@ import java.util.Collections;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey; import org.signal.chat.common.KemSignedPreKey;
@ -48,27 +48,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
class KeysAnonymousGrpcServiceTest { class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcService, KeysAnonymousGrpc.KeysAnonymousBlockingStub> {
@Mock
private AccountsManager accountsManager; private AccountsManager accountsManager;
@Mock
private KeysManager keysManager; private KeysManager keysManager;
private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
@RegisterExtension @Override
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); protected KeysAnonymousGrpcService createServiceBeforeEachTest() {
return new KeysAnonymousGrpcService(accountsManager, keysManager);
@BeforeEach
void setUp() {
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
final KeysAnonymousGrpcService keysGrpcService =
new KeysAnonymousGrpcService(accountsManager, keysManager);
keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
} }
@Test @Test
@ -101,7 +92,7 @@ class KeysAnonymousGrpcServiceTest {
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey))); when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey); when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.setRequest(GetPreKeysRequest.newBuilder() .setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
@ -152,19 +143,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier))) when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException statusRuntimeException = assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
assertThrows(StatusRuntimeException.class, .setRequest(GetPreKeysRequest.newBuilder()
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setRequest(GetPreKeysRequest.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setUuid(UUIDUtil.toByteString(identifier))
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) .build())
.setUuid(UUIDUtil.toByteString(identifier)) .setDeviceId(Device.MASTER_ID)
.build()) .build())
.setDeviceId(Device.MASTER_ID) .build()));
.build())
.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
} }
@Test @Test
@ -174,7 +161,7 @@ class KeysAnonymousGrpcServiceTest {
final StatusRuntimeException exception = final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, assertThrows(StatusRuntimeException.class,
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID())) .setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
.setRequest(GetPreKeysRequest.newBuilder() .setRequest(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
@ -205,19 +192,15 @@ class KeysAnonymousGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier))) when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException exception = assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
assertThrows(StatusRuntimeException.class, .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() .setRequest(GetPreKeysRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setRequest(GetPreKeysRequest.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setUuid(UUIDUtil.toByteString(accountIdentifier))
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) .build())
.setUuid(UUIDUtil.toByteString(accountIdentifier)) .setDeviceId(deviceId)
.build()) .build())
.setDeviceId(deviceId) .build()));
.build())
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
} }
} }

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
@ -16,9 +15,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.ServerInterceptors;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import java.time.Duration; import java.time.Duration;
@ -32,14 +32,13 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey; import org.signal.chat.common.KemSignedPreKey;
@ -56,7 +55,6 @@ import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@ -73,41 +71,34 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
class KeysGrpcServiceTest { class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.KeysBlockingStub> {
private AccountsManager accountsManager;
private KeysManager keysManager;
private RateLimiter preKeysRateLimiter;
private Device authenticatedDevice;
private KeysGrpc.KeysBlockingStub keysStub;
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension protected static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach @Mock
void setUp() { private AccountsManager accountsManager;
accountsManager = mock(AccountsManager.class);
keysManager = mock(KeysManager.class);
preKeysRateLimiter = mock(RateLimiter.class);
@Mock
private KeysManager keysManager;
@Mock
private RateLimiter preKeysRateLimiter;
@Mock
private Device authenticatedDevice;
@Override
protected KeysGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class); final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
authenticatedDevice = mock(Device.class);
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID); when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
final Account authenticatedAccount = mock(Account.class); final Account authenticatedAccount = mock(Account.class);
@ -119,17 +110,13 @@ class KeysGrpcServiceTest {
when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey())); when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice)); when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount)); when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount)); when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
return new KeysGrpcService(accountsManager, keysManager, rateLimiters);
} }
@Test @Test
@ -152,7 +139,7 @@ class KeysGrpcServiceTest {
.setPniEcPreKeyCount(3) .setPniEcPreKeyCount(3)
.setPniKemPreKeyCount(4) .setPniKemPreKeyCount(4)
.build(), .build(),
keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build())); authenticatedServiceStub().getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
} }
@ParameterizedTest @ParameterizedTest
@ -168,7 +155,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder() authenticatedServiceStub().setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
.setIdentityType(identityType) .setIdentityType(identityType)
.addAllPreKeys(preKeys.stream() .addAllPreKeys(preKeys.stream()
.map(preKey -> EcPreKey.newBuilder() .map(preKey -> EcPreKey.newBuilder()
@ -189,10 +176,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) { void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
final StatusRuntimeException exception = assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeEcPreKeys(request));
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
private static Stream<Arguments> setOneTimeEcPreKeysWithError() { private static Stream<Arguments> setOneTimeEcPreKeysWithError() {
@ -242,7 +226,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
keysStub.setOneTimeKemSignedPreKeys( authenticatedServiceStub().setOneTimeKemSignedPreKeys(
SetOneTimeKemSignedPreKeysRequest.newBuilder() SetOneTimeKemSignedPreKeysRequest.newBuilder()
.setIdentityType(identityType) .setIdentityType(identityType)
.addAllPreKeys(preKeys.stream() .addAllPreKeys(preKeys.stream()
@ -265,10 +249,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) { void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
final StatusRuntimeException exception = assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeKemSignedPreKeys(request));
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
private static Stream<Arguments> setOneTimeKemSignedPreKeysWithError() { private static Stream<Arguments> setOneTimeKemSignedPreKeysWithError() {
@ -333,7 +314,7 @@ class KeysGrpcServiceTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder() authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
.setIdentityType(identityType) .setIdentityType(identityType)
.setSignedPreKey(EcSignedPreKey.newBuilder() .setSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(signedPreKey.keyId()) .setKeyId(signedPreKey.keyId())
@ -359,7 +340,7 @@ class KeysGrpcServiceTest {
@MethodSource @MethodSource
void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) { void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
final StatusRuntimeException exception = final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request)); assertThrows(StatusRuntimeException.class, () -> authenticatedServiceStub().setEcSignedPreKey(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
@ -416,7 +397,7 @@ class KeysGrpcServiceTest {
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair); final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder() authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
.setIdentityType(identityType) .setIdentityType(identityType)
.setSignedPreKey(KemSignedPreKey.newBuilder() .setSignedPreKey(KemSignedPreKey.newBuilder()
.setKeyId(lastResortPreKey.keyId()) .setKeyId(lastResortPreKey.keyId())
@ -437,10 +418,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) { void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
final StatusRuntimeException exception = assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setKemLastResortPreKey(request));
assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
private static Stream<Arguments> setLastResortPreKeyWithError() { private static Stream<Arguments> setLastResortPreKeyWithError() {
@ -528,7 +506,7 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
{ {
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType) .setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier)) .setUuid(UUIDUtil.toByteString(identifier))
@ -563,7 +541,7 @@ class KeysGrpcServiceTest {
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
{ {
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(grpcIdentityType) .setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier)) .setUuid(UUIDUtil.toByteString(identifier))
@ -606,15 +584,12 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())) when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final StatusRuntimeException exception = assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.setUuid(UUIDUtil.toByteString(UUID.randomUUID())) .build())
.build()) .build()));
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
} }
@ParameterizedTest @ParameterizedTest
@ -631,16 +606,13 @@ class KeysGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier))) when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final StatusRuntimeException exception = assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) .setUuid(UUIDUtil.toByteString(accountIdentifier))
.setUuid(UUIDUtil.toByteString(accountIdentifier)) .build())
.build()) .setDeviceId(deviceId)
.setDeviceId(deviceId) .build()));
.build()));
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
} }
@Test @Test
@ -655,22 +627,15 @@ class KeysGrpcServiceTest {
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Duration retryAfterDuration = Duration.ofMinutes(7); final Duration retryAfterDuration = Duration.ofMinutes(7);
when(preKeysRateLimiter.validateReactive(anyString())) when(preKeysRateLimiter.validateReactive(anyString()))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false))); .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
final StatusRuntimeException exception = assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder() .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
.setUuid(UUIDUtil.toByteString(UUID.randomUUID())) .build())
.build()) .build()));
.build()));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(accountsManager); verifyNoInteractions(accountsManager);
} }
} }

View File

@ -1,19 +1,27 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc; package org.whispersystems.textsecuregcm.grpc;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException; import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.stub.MetadataUtils;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Instant; import java.time.Instant;
@ -24,14 +32,12 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream; import java.util.stream.Stream;
import io.grpc.StatusRuntimeException; import javax.annotation.Nullable;
import io.grpc.stub.MetadataUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.signal.chat.common.IdentityType; import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier; import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType; import org.signal.chat.profile.CredentialType;
@ -72,43 +78,41 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import javax.annotation.Nullable;
public class ProfileAnonymousGrpcServiceTest { public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@Mock
private Account account; private Account account;
@Mock
private AccountsManager accountsManager; private AccountsManager accountsManager;
@Mock
private ProfilesManager profilesManager; private ProfilesManager profilesManager;
@Mock
private ProfileBadgeConverter profileBadgeConverter; private ProfileBadgeConverter profileBadgeConverter;
private ProfileAnonymousGrpc.ProfileAnonymousBlockingStub profileAnonymousBlockingStub;
@Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); @Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
@BeforeEach return new ProfileAnonymousGrpcService(
void setup() {
account = mock(Account.class);
accountsManager = mock(AccountsManager.class);
profilesManager = mock(ProfilesManager.class);
profileBadgeConverter = mock(ProfileBadgeConverter.class);
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
profileAnonymousBlockingStub = ProfileAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final ProfileAnonymousGrpcService profileAnonymousGrpcService = new ProfileAnonymousGrpcService(
accountsManager, accountsManager,
profilesManager, profilesManager,
profileBadgeConverter, profileBadgeConverter,
serverZkProfileOperations serverZkProfileOperations
); );
}
GRPC_SERVER_EXTENSION.getServiceRegistry() @Override
.addService(profileAnonymousGrpcService); protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
} }
@Test @Test
@ -151,7 +155,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final GetUnversionedProfileResponse response = profileAnonymousBlockingStub.getUnversionedProfile(request); final GetUnversionedProfileResponse response = unauthenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey); final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder() final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder()
@ -189,10 +193,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
} }
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getUnversionedProfile(requestBuilder.build()));
() -> profileAnonymousBlockingStub.getUnversionedProfile(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
} }
private static Stream<Arguments> getUnversionedProfileUnauthenticated() { private static Stream<Arguments> getUnversionedProfileUnauthenticated() {
@ -242,7 +243,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final GetVersionedProfileResponse response = profileAnonymousBlockingStub.getVersionedProfile(request); final GetVersionedProfileResponse response = unauthenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder() final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name)) .setName(ByteString.copyFrom(name))
@ -287,10 +288,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getVersionedProfile(request));
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
} }
@ParameterizedTest @ParameterizedTest
@ -318,18 +316,15 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
} }
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getVersionedProfile(requestBuilder.build()));
() -> profileAnonymousBlockingStub.getVersionedProfile(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
} }
private static Stream<Arguments> getVersionedProfileUnauthenticated() { private static Stream<Arguments> getVersionedProfileUnauthenticated() {
return Stream.of( return Stream.of(
Arguments.of(true, false), Arguments.of(true, false),
Arguments.of(false, true) Arguments.of(false, true)
); );
} }
@Test @Test
void getVersionedProfilePniInvalidArgument() { void getVersionedProfilePniInvalidArgument() {
final byte[] unidentifiedAccessKey = new byte[16]; final byte[] unidentifiedAccessKey = new byte[16];
@ -346,10 +341,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getVersionedProfile(request));
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
} }
@Test @Test
@ -404,7 +396,7 @@ public class ProfileAnonymousGrpcServiceTest {
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
.build(); .build();
final GetExpiringProfileKeyCredentialResponse response = profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request); final GetExpiringProfileKeyCredentialResponse response = unauthenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray()); assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@ -442,10 +434,7 @@ public class ProfileAnonymousGrpcServiceTest {
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
} }
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(requestBuilder.build()));
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(requestBuilder.build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
verifyNoInteractions(profilesManager); verifyNoInteractions(profilesManager);
} }
@ -483,10 +472,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
} }
@ParameterizedTest @ParameterizedTest
@ -521,10 +507,7 @@ public class ProfileAnonymousGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
} }
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() { private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {

View File

@ -10,8 +10,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
@ -21,15 +19,14 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber; import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerInterceptors;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.MetadataUtils;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Clock; import java.time.Clock;
@ -44,15 +41,14 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.signal.chat.common.IdentityType; import org.signal.chat.common.IdentityType;
import org.signal.chat.common.ServiceIdentifier; import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.profile.CredentialType; import org.signal.chat.profile.CredentialType;
@ -82,7 +78,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext; import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations; import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter; import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
@ -100,49 +95,54 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
public class ProfileGrpcServiceTest { public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcService, ProfileGrpc.ProfileBlockingStub> {
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
private static final String S3_BUCKET = "profileBucket"; private static final String S3_BUCKET = "profileBucket";
private static final String VERSION = "someVersion"; private static final String VERSION = "someVersion";
private static final byte[] VALID_NAME = new byte[81]; private static final byte[] VALID_NAME = new byte[81];
@Mock
private AccountsManager accountsManager; private AccountsManager accountsManager;
@Mock
private ProfilesManager profilesManager; private ProfilesManager profilesManager;
@Mock
private DynamicPaymentsConfiguration dynamicPaymentsConfiguration; private DynamicPaymentsConfiguration dynamicPaymentsConfiguration;
@Mock
private S3AsyncClient asyncS3client; private S3AsyncClient asyncS3client;
@Mock
private VersionedProfile profile; private VersionedProfile profile;
@Mock
private Account account; private Account account;
@Mock
private RateLimiter rateLimiter; private RateLimiter rateLimiter;
@Mock
private ProfileBadgeConverter profileBadgeConverter; private ProfileBadgeConverter profileBadgeConverter;
@Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
private ProfileGrpc.ProfileBlockingStub profileBlockingStub;
@RegisterExtension
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
profilesManager = mock(ProfilesManager.class);
dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
asyncS3client = mock(S3AsyncClient.class);
profile = mock(VersionedProfile.class);
account = mock(Account.class);
rateLimiter = mock(RateLimiter.class);
profileBadgeConverter = mock(ProfileBadgeConverter.class);
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
@Override
protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1"); final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1");
@ -170,30 +170,6 @@ public class ProfileGrpcServiceTest {
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us"); metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3"); metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
profileBlockingStub = ProfileGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final ProfileGrpcService profileGrpcService = new ProfileGrpcService(
Clock.systemUTC(),
accountsManager,
profilesManager,
dynamicConfigurationManager,
badgesConfiguration,
asyncS3client,
policyGenerator,
policySigner,
profileBadgeConverter,
rateLimiters,
serverZkProfileOperations,
S3_BUCKET
);
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
GRPC_SERVER_EXTENSION.getServiceRegistry()
.addService(ServerInterceptors.intercept(profileGrpcService, mockAuthenticationInterceptor));
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter); when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty()); when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
@ -218,6 +194,21 @@ public class ProfileGrpcServiceTest {
when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList()); when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList());
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
return new ProfileGrpcService(
Clock.systemUTC(),
accountsManager,
profilesManager,
dynamicConfigurationManager,
badgesConfiguration,
asyncS3client,
policyGenerator,
policySigner,
profileBadgeConverter,
rateLimiters,
serverZkProfileOperations,
S3_BUCKET
);
} }
@Test @Test
@ -237,7 +228,7 @@ public class ProfileGrpcServiceTest {
.setCommitment(ByteString.copyFrom(commitment)) .setCommitment(ByteString.copyFrom(commitment))
.build(); .build();
profileBlockingStub.setProfile(request); authenticatedServiceStub().setProfile(request);
final ArgumentCaptor<VersionedProfile> profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class); final ArgumentCaptor<VersionedProfile> profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class);
@ -274,7 +265,7 @@ public class ProfileGrpcServiceTest {
hasPreviousProfile ? Optional.of(profile) : Optional.empty())); hasPreviousProfile ? Optional.of(profile) : Optional.empty()));
when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
SetProfileResponse response = profileBlockingStub.setProfile(request); final SetProfileResponse response = authenticatedServiceStub().setProfile(request);
if (expectHasS3UploadPath) { if (expectHasS3UploadPath) {
assertTrue(response.getAttributes().getPath().startsWith("profiles/")); assertTrue(response.getAttributes().getPath().startsWith("profiles/"));
@ -312,10 +303,7 @@ public class ProfileGrpcServiceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void setProfileInvalidRequestData(final SetProfileRequest request) { void setProfileInvalidRequestData(final SetProfileRequest request) {
final StatusRuntimeException exception = assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setProfile(request));
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
private static Stream<Arguments> setProfileInvalidRequestData() throws InvalidInputException{ private static Stream<Arguments> setProfileInvalidRequestData() throws InvalidInputException{
@ -386,12 +374,10 @@ public class ProfileGrpcServiceTest {
when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
if (hasExistingPaymentAddress) { if (hasExistingPaymentAddress) {
assertDoesNotThrow(() -> profileBlockingStub.setProfile(request), assertDoesNotThrow(() -> authenticatedServiceStub().setProfile(request),
"Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address"); "Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address");
} else { } else {
final StatusRuntimeException exception = assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().setProfile(request));
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
assertEquals(Status.PERMISSION_DENIED.getCode(), exception.getStatus().getCode());
} }
} }
@ -433,7 +419,7 @@ public class ProfileGrpcServiceTest {
when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges); when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges);
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
final GetUnversionedProfileResponse response = profileBlockingStub.getUnversionedProfile(request); final GetUnversionedProfileResponse response = authenticatedServiceStub().getUnversionedProfile(request);
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey); final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder() final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder()
@ -472,10 +458,7 @@ public class ProfileGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getUnversionedProfile(request));
() -> profileBlockingStub.getUnversionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
} }
@ParameterizedTest @ParameterizedTest
@ -493,14 +476,7 @@ public class ProfileGrpcServiceTest {
.build()) .build())
.build(); .build();
final StatusRuntimeException exception = assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getUnversionedProfile(request), accountsManager);
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.getUnversionedProfile(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(accountsManager);
} }
@ParameterizedTest @ParameterizedTest
@ -531,7 +507,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
final GetVersionedProfileResponse response = profileBlockingStub.getVersionedProfile(request); final GetVersionedProfileResponse response = authenticatedServiceStub().getVersionedProfile(request);
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder() final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
.setName(ByteString.copyFrom(name)) .setName(ByteString.copyFrom(name))
@ -545,7 +521,6 @@ public class ProfileGrpcServiceTest {
assertEquals(expectedResponseBuilder.build(), response); assertEquals(expectedResponseBuilder.build(), response);
} }
private static Stream<Arguments> getVersionedProfile() { private static Stream<Arguments> getVersionedProfile() {
return Stream.of( return Stream.of(
Arguments.of("version1", "version1", true), Arguments.of("version1", "version1", true),
@ -553,6 +528,7 @@ public class ProfileGrpcServiceTest {
Arguments.of("version1", "version2", false) Arguments.of("version1", "version2", false)
); );
} }
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) { void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) {
@ -566,10 +542,7 @@ public class ProfileGrpcServiceTest {
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account))); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account)));
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile))); when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile)));
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getVersionedProfile(request));
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
} }
private static Stream<Arguments> getVersionedProfileAccountOrProfileNotFound() { private static Stream<Arguments> getVersionedProfileAccountOrProfileNotFound() {
@ -581,10 +554,7 @@ public class ProfileGrpcServiceTest {
@Test @Test
void getVersionedProfileRatelimited() { void getVersionedProfileRatelimited() {
final Duration retryAfterDuration = Duration.ofMinutes(7); final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(7), false);
when(rateLimiter.validateReactive(any(UUID.class)))
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder() final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder()
.setAccountIdentifier(ServiceIdentifier.newBuilder() .setAccountIdentifier(ServiceIdentifier.newBuilder()
@ -594,15 +564,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getVersionedProfile(request), accountsManager, profilesManager);
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(accountsManager);
verifyNoInteractions(profilesManager);
} }
@Test @Test
@ -615,9 +577,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getVersionedProfile(request));
() -> profileBlockingStub.getVersionedProfile(request));
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
} }
@Test @Test
@ -664,7 +624,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
final GetExpiringProfileKeyCredentialResponse response = profileBlockingStub.getExpiringProfileKeyCredential(request); final GetExpiringProfileKeyCredentialResponse response = authenticatedServiceStub().getExpiringProfileKeyCredential(request);
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray()); assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
@ -677,9 +637,8 @@ public class ProfileGrpcServiceTest {
@Test @Test
void getExpiringProfileKeyCredentialRateLimited() { void getExpiringProfileKeyCredentialRateLimited() {
final Duration retryAfterDuration = Duration.ofMinutes(5); final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(
when(rateLimiter.validateReactive(AUTHENTICATED_ACI)) rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(5), false);
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder() final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder()
@ -692,14 +651,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request), profilesManager);
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
verifyNoInteractions(profilesManager);
} }
@ParameterizedTest @ParameterizedTest
@ -723,10 +675,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class, assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.NOT_FOUND, statusRuntimeException.getStatus().getCode());
} }
private static Stream<Arguments> getExpiringProfileKeyCredentialAccountOrProfileNotFound() { private static Stream<Arguments> getExpiringProfileKeyCredentialAccountOrProfileNotFound() {
@ -761,10 +710,7 @@ public class ProfileGrpcServiceTest {
.setVersion("someVersion") .setVersion("someVersion")
.build(); .build();
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class, assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
} }
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() { private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {

View File

@ -0,0 +1,146 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static java.util.Objects.requireNonNull;
import io.grpc.BindableService;
import io.grpc.Channel;
import io.grpc.stub.AbstractBlockingStub;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.UUID;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.MockitoAnnotations;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.storage.Device;
/**
* Base class for the common case of gRPC services tests. This base class makes some assumptions
* and introduces some constraints on the implementing classes with a goal of simplifying the process
* of creating a test for the most of the gRPC services.
* <ul>
* <li>
* Test classes extending this class will have to override the {@link #createServiceBeforeEachTest()} method
* with the logic that creates an instance of the service to test. This method is called before each test and should
* contain other setup code that would normally go into {@code @BeforeEach} method.
* </li>
* <li>
* This base class takes care of creating two service stubs: {@code authenticatedServiceStub} and {@code unauthenticatedServiceStub}.
* Normally, those stubs are created by the call to the {@code newBlockingStub()} method on the {@code *Stub} class, e.g.:
* <pre>CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());</pre>
* In this class, those stubs are created by the {@link #createStub(Channel)} method that has a default implementation that is based on
* figuring out the name of the {@code `*Grpc`} class and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
* </li>
* <li>
* This class takes care of initializing {@code Mockito} annotations processing, so implementing classes
* can annotate their fields with {@code @Mock} and have those mocks ready by the time {@link #createServiceBeforeEachTest()} is called.
* </li>
* </ul>
* @param <SERVICE> Class of the gRPC service that is being tested.
* @param <STUB> Class of the gRPC service stub.
*/
public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB extends AbstractBlockingStub<STUB>> {
@RegisterExtension
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_AUTHENTICATED = new GrpcServerExtension();
@RegisterExtension
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_UNAUTHENTICATED = new GrpcServerExtension();
protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
protected static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
private AutoCloseable mocksCloseable;
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
private SERVICE service;
private STUB authenticatedServiceStub;
private STUB unauthenticatedServiceStub;
/**
* This method is invoked before each test and is expected to create an instance of the gRPC service
* that is being tested and also to perform all necessary before-each setup.
* </p>
* Extending classes may have their own {@code @BeforeEach} method, but it will be called after this method.
* @return an instance of the gRPC service.
*/
protected abstract SERVICE createServiceBeforeEachTest();
/**
* The default implementation of this method is based on figuring out the name of the {@code `*Grpc`} class
* and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
* <p>
* Overriding this method can be helpful if addutional configuration of the stub is required, e.g. adding interceptors:
* <pre>
* protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
* final Metadata metadata = new Metadata();
* metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
* metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
* return super.createStub(channel)
* .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
* }
* </pre>
* @param channel grpc channel to create create the stub for.
* @return and instance of the service stub.
*/
protected STUB createStub(final Channel channel) throws
ClassNotFoundException,
NoSuchMethodException,
InvocationTargetException,
IllegalAccessException {
final String serviceClassName = service.bindService().getServiceDescriptor().getName();
final String grpcClassName = serviceClassName + "Grpc";
final Class<?> grpcClass = ClassLoader.getSystemClassLoader().loadClass(grpcClassName);
final Method newBlockingStubMethod = grpcClass.getMethod("newBlockingStub", Channel.class);
final Object stub = newBlockingStubMethod.invoke(null, channel);
//noinspection unchecked
return (STUB) stub;
}
@BeforeEach
protected void baseSetup() {
mocksCloseable = MockitoAnnotations.openMocks(this);
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
mockAuthenticationInterceptor = GrpcTestUtils.setupAuthenticatedExtension(
GRPC_SERVER_EXTENSION_AUTHENTICATED, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, service);
try {
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
} catch (Exception e) {
throw new RuntimeException("Could not create a stub based on the service name. Try overriding `createStub()` method.");
}
}
@AfterEach
public void releaseMocks() throws Exception {
mocksCloseable.close();
}
public MockAuthenticationInterceptor mockAuthenticationInterceptor() {
return mockAuthenticationInterceptor;
}
protected SERVICE service() {
return service;
}
protected STUB authenticatedServiceStub() {
return authenticatedServiceStub;
}
protected STUB unauthenticatedServiceStub() {
return unauthenticatedServiceStub;
}
}

View File

@ -111,7 +111,7 @@ public class RateLimitersTest {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig(); final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config()); assertEquals(expected, config(limiter));
} }
@Test @Test
@ -131,16 +131,16 @@ public class RateLimitersTest {
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
assertEquals(initialRateLimiterConfig, limiter.config()); assertEquals(initialRateLimiterConfig, config(limiter));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config()); assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config()); assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
assertEquals(updatedRateLimiterCongig, limiter.config()); assertEquals(updatedRateLimiterCongig, config(limiter));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config()); assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config()); assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
} }
@Test @Test
@ -161,22 +161,22 @@ public class RateLimitersTest {
// test only default is present // test only default is present
mapForDynamic.remove(descriptor.id()); mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id()); mapForStatic.remove(descriptor.id());
assertEquals(defaultConfig, limiter.config()); assertEquals(defaultConfig, config(limiter));
// test dynamic and no static // test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic); mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id()); mapForStatic.remove(descriptor.id());
assertEquals(configForDynamic, limiter.config()); assertEquals(configForDynamic, config(limiter));
// test dynamic and static // test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic); mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic); mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForDynamic, limiter.config()); assertEquals(configForDynamic, config(limiter));
// test static, but no dynamic // test static, but no dynamic
mapForDynamic.remove(descriptor.id()); mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic); mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForStatic, limiter.config()); assertEquals(configForStatic, config(limiter));
} }
private record TestDescriptor(String id) implements RateLimiterDescriptor { private record TestDescriptor(String id) implements RateLimiterDescriptor {
@ -191,4 +191,14 @@ public class RateLimitersTest {
return new RateLimiterConfig(1, Duration.ofMinutes(1)); return new RateLimiterConfig(1, Duration.ofMinutes(1));
} }
} }
private static RateLimiterConfig config(final RateLimiter rateLimiter) {
if (rateLimiter instanceof StaticRateLimiter rm) {
return rm.config();
}
if (rateLimiter instanceof DynamicRateLimiter rm) {
return rm.config();
}
throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName());
}
} }

View File

@ -12,12 +12,14 @@ import static org.mockito.Mockito.doThrow;
import java.time.Duration; import java.time.Duration;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.RandomUtils;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import reactor.core.publisher.Mono;
public final class MockUtils { public final class MockUtils {
@ -46,32 +48,80 @@ public final class MockUtils {
} }
public static void updateRateLimiterResponseToAllow( public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock, final RateLimiter mockRateLimiter,
final RateLimiters.For handle,
final String input) { final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try { try {
doNothing().when(mockRateLimiter).validate(eq(input)); doNothing().when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static void updateRateLimiterResponseToAllow(
final RateLimiter mockRateLimiter,
final UUID input) {
try {
doNothing().when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
updateRateLimiterResponseToAllow(mockRateLimiter, input);
}
public static void updateRateLimiterResponseToAllow( public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.For handle, final RateLimiters.For handle,
final UUID input) { final UUID input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
updateRateLimiterResponseToAllow(mockRateLimiter, input);
}
public static Duration updateRateLimiterResponseToFail(
final RateLimiter mockRateLimiter,
final String input,
final Duration retryAfter,
final boolean legacyStatusCode) {
try { try {
doNothing().when(mockRateLimiter).validate(eq(input)); final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
doThrow(exception).when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
return retryAfter;
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
public static void updateRateLimiterResponseToFail( public static Duration updateRateLimiterResponseToFail(
final RateLimiter mockRateLimiter,
final UUID input,
final Duration retryAfter,
final boolean legacyStatusCode) {
try {
final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
doThrow(exception).when(mockRateLimiter).validate(eq(input));
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
return retryAfter;
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.For handle, final RateLimiters.For handle,
final String input, final String input,
@ -79,14 +129,10 @@ public final class MockUtils {
final boolean legacyStatusCode) { final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try { return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
} }
public static void updateRateLimiterResponseToFail( public static Duration updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.For handle, final RateLimiters.For handle,
final UUID input, final UUID input,
@ -94,11 +140,7 @@ public final class MockUtils {
final boolean legacyStatusCode) { final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try { return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
} }
public static SecretBytes randomSecretBytes(final int size) { public static SecretBytes randomSecretBytes(final int size) {