DRY gRPC tests, refactor error mapping
This commit is contained in:
parent
29ca544c95
commit
977243ebfd
|
@ -120,6 +120,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
|
|||
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
|
||||
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
|
||||
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcServerManagedWrapper;
|
||||
import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
|
||||
|
@ -644,8 +645,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager));
|
||||
|
||||
final ServerBuilder<?> grpcServer = ServerBuilder.forPort(config.getGrpcPort())
|
||||
// TODO: specialize metrics with user-agent platform
|
||||
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
|
||||
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
|
||||
.addService(new KeysAnonymousGrpcService(accountsManager, keys))
|
||||
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
|
||||
|
@ -657,13 +656,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
.addFilter("RemoteDeprecationFilter", remoteDeprecationFilter)
|
||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||
|
||||
grpcServer.intercept(new AcceptLanguageInterceptor());
|
||||
|
||||
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||
// depends on the user-agent context so it has to come first here!
|
||||
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
||||
grpcServer.intercept(remoteDeprecationFilter);
|
||||
grpcServer.intercept(new UserAgentInterceptor());
|
||||
grpcServer
|
||||
// TODO: specialize metrics with user-agent platform
|
||||
.intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
|
||||
.intercept(new ErrorMappingInterceptor())
|
||||
.intercept(new AcceptLanguageInterceptor())
|
||||
.intercept(remoteDeprecationFilter)
|
||||
.intercept(new UserAgentInterceptor());
|
||||
|
||||
environment.lifecycle().manage(new GrpcServerManagedWrapper(grpcServer.build()));
|
||||
|
||||
|
|
|
@ -1,14 +1,30 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import java.time.Duration;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.grpc.ConvertibleToGrpcStatus;
|
||||
|
||||
public class RateLimitExceededException extends Exception {
|
||||
public class RateLimitExceededException extends Exception implements ConvertibleToGrpcStatus {
|
||||
|
||||
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
|
||||
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
|
||||
@Override
|
||||
public String toAsciiString(final Duration value) {
|
||||
return value.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Duration parseAsciiString(final String serialized) {
|
||||
return Duration.parse(serialized);
|
||||
}
|
||||
});
|
||||
|
||||
@Nullable
|
||||
private final Duration retryDuration;
|
||||
|
@ -33,4 +49,19 @@ public class RateLimitExceededException extends Exception {
|
|||
public boolean isLegacy() {
|
||||
return legacy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status grpcStatus() {
|
||||
return Status.RESOURCE_EXHAUSTED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<Metadata> grpcMetadata() {
|
||||
return getRetryDuration()
|
||||
.map(duration -> {
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
|
||||
return metadata;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,11 +24,6 @@ public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase {
|
|||
this.rateLimiters = rateLimiters;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Throwable onErrorMap(final Throwable throwable) {
|
||||
return RateLimitUtil.mapRateLimitExceededException(throwable);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<GetTurnCredentialsResponse> getTurnCredentials(final GetTurnCredentialsRequest request) {
|
||||
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Interface to be imlemented by our custom exceptions that are consistently mapped to a gRPC status.
|
||||
*/
|
||||
public interface ConvertibleToGrpcStatus {
|
||||
|
||||
Status grpcStatus();
|
||||
|
||||
Optional<Metadata> grpcMetadata();
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.ForwardingServerCall;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
|
||||
/**
|
||||
* This interceptor observes responses from the service and if the response status is {@link Status#UNKNOWN}
|
||||
* and there is a non-null cause which is an instance of {@link ConvertibleToGrpcStatus},
|
||||
* then status and metadata to be returned to the client is resolved from that object.
|
||||
* </p>
|
||||
* This eliminates the need of having each service to override {@code `onErrorMap()`} method for commonly used exceptions.
|
||||
*/
|
||||
public class ErrorMappingInterceptor implements ServerInterceptor {
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
|
||||
final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
return next.startCall(new ForwardingServerCall.SimpleForwardingServerCall<>(call) {
|
||||
@Override
|
||||
public void close(final Status status, final Metadata trailers) {
|
||||
// The idea is to only apply the automatic conversion logic in the cases
|
||||
// when there was no explicit decision by the service to provide a status.
|
||||
// I.e. if at this point we see anything but the `UNKNOWN`,
|
||||
// that means that some logic in the service made this decision already
|
||||
// and automatic conversion may conflict with it.
|
||||
if (status.getCode().equals(Status.Code.UNKNOWN)
|
||||
&& status.getCause() instanceof ConvertibleToGrpcStatus convertibleToGrpcStatus) {
|
||||
super.close(
|
||||
convertibleToGrpcStatus.grpcStatus(),
|
||||
convertibleToGrpcStatus.grpcMetadata().orElseGet(Metadata::new)
|
||||
);
|
||||
} else {
|
||||
super.close(status, trailers);
|
||||
}
|
||||
}
|
||||
}, headers);
|
||||
}
|
||||
}
|
|
@ -74,11 +74,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
|
|||
this.rateLimiters = rateLimiters;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Throwable onErrorMap(final Throwable throwable) {
|
||||
return RateLimitUtil.mapRateLimitExceededException(throwable);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<GetPreKeyCountResponse> getPreKeyCount(final GetPreKeyCountRequest request) {
|
||||
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
|
||||
|
|
|
@ -100,11 +100,6 @@ public class ProfileGrpcService extends ReactorProfileGrpc.ProfileImplBase {
|
|||
this.bucket = bucket;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Throwable onErrorMap(final Throwable throwable) {
|
||||
return RateLimitUtil.mapRateLimitExceededException(throwable);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<SetProfileResponse> setProfile(final SetProfileRequest request) {
|
||||
validateRequest(request);
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusException;
|
||||
import java.time.Duration;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
|
||||
public class RateLimitUtil {
|
||||
|
||||
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
|
||||
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
|
||||
@Override
|
||||
public String toAsciiString(final Duration value) {
|
||||
return value.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Duration parseAsciiString(final String serialized) {
|
||||
return Duration.parse(serialized);
|
||||
}
|
||||
});
|
||||
|
||||
public static Throwable mapRateLimitExceededException(final Throwable throwable) {
|
||||
if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
|
||||
@Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
|
||||
.map(duration -> {
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
|
||||
|
||||
return metadata;
|
||||
}).orElse(null);
|
||||
|
||||
return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
|
||||
}
|
||||
|
||||
return throwable;
|
||||
}
|
||||
}
|
|
@ -6,66 +6,41 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
|
||||
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.calling.CallingGrpc;
|
||||
import org.signal.chat.calling.GetTurnCredentialsRequest;
|
||||
import org.signal.chat.calling.GetTurnCredentialsResponse;
|
||||
import org.whispersystems.textsecuregcm.auth.TurnToken;
|
||||
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import reactor.core.publisher.Mono;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
|
||||
class CallingGrpcServiceTest {
|
||||
class CallingGrpcServiceTest extends SimpleBaseGrpcTest<CallingGrpcService, CallingGrpc.CallingBlockingStub> {
|
||||
|
||||
@Mock
|
||||
private TurnTokenGenerator turnTokenGenerator;
|
||||
|
||||
@Mock
|
||||
private RateLimiter turnCredentialRateLimiter;
|
||||
|
||||
private CallingGrpc.CallingBlockingStub callingStub;
|
||||
|
||||
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
turnTokenGenerator = mock(TurnTokenGenerator.class);
|
||||
turnCredentialRateLimiter = mock(RateLimiter.class);
|
||||
|
||||
@Override
|
||||
protected CallingGrpcService createServiceBeforeEachTest() {
|
||||
final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter);
|
||||
|
||||
final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters);
|
||||
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor));
|
||||
|
||||
callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
return new CallingGrpcService(turnTokenGenerator, rateLimiters);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -74,10 +49,10 @@ class CallingGrpcServiceTest {
|
|||
final String password = "test-password";
|
||||
final List<String> urls = List.of("first", "second");
|
||||
|
||||
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty());
|
||||
MockUtils.updateRateLimiterResponseToAllow(turnCredentialRateLimiter, AUTHENTICATED_ACI);
|
||||
when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls));
|
||||
|
||||
final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
|
||||
final GetTurnCredentialsResponse response = authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build());
|
||||
|
||||
final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder()
|
||||
.setUsername(username)
|
||||
|
@ -90,20 +65,10 @@ class CallingGrpcServiceTest {
|
|||
|
||||
@Test
|
||||
void getTurnCredentialsRateLimited() {
|
||||
final Duration retryAfter = Duration.ofMinutes(19);
|
||||
|
||||
when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI))
|
||||
.thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false)));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
|
||||
|
||||
final Duration retryAfter = MockUtils.updateRateLimiterResponseToFail(
|
||||
turnCredentialRateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(19), false);
|
||||
assertRateLimitExceeded(retryAfter, () -> authenticatedServiceStub().getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()));
|
||||
verify(turnTokenGenerator, never()).generate(any());
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
|
||||
verifyNoInteractions(turnTokenGenerator);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
@ -14,11 +13,10 @@ import static org.mockito.Mockito.never;
|
|||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
|
@ -26,21 +24,19 @@ import java.time.temporal.ChronoUnit;
|
|||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.device.ClearPushTokenRequest;
|
||||
import org.signal.chat.device.ClearPushTokenResponse;
|
||||
import org.signal.chat.device.DevicesGrpc;
|
||||
|
@ -54,42 +50,31 @@ import org.signal.chat.device.SetDeviceNameRequest;
|
|||
import org.signal.chat.device.SetDeviceNameResponse;
|
||||
import org.signal.chat.device.SetPushTokenRequest;
|
||||
import org.signal.chat.device.SetPushTokenResponse;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
|
||||
class DevicesGrpcServiceTest {
|
||||
class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, DevicesGrpc.DevicesBlockingStub> {
|
||||
|
||||
@Mock
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
@Mock
|
||||
private KeysManager keysManager;
|
||||
|
||||
@Mock
|
||||
private MessagesManager messagesManager;
|
||||
|
||||
@Mock
|
||||
private Account authenticatedAccount;
|
||||
|
||||
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
|
||||
private DevicesGrpc.DevicesBlockingStub devicesStub;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
keysManager = mock(KeysManager.class);
|
||||
messagesManager = mock(MessagesManager.class);
|
||||
|
||||
authenticatedAccount = mock(Account.class);
|
||||
@Override
|
||||
protected DevicesGrpcService createServiceBeforeEachTest() {
|
||||
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
|
||||
|
||||
mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
|
||||
|
||||
|
@ -117,11 +102,7 @@ class DevicesGrpcServiceTest {
|
|||
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
final DevicesGrpcService devicesGrpcService = new DevicesGrpcService(accountsManager, keysManager, messagesManager);
|
||||
devicesStub = DevicesGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(devicesGrpcService, mockAuthenticationInterceptor));
|
||||
return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -161,14 +142,14 @@ class DevicesGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
assertEquals(expectedResponse, devicesStub.getDevices(GetDevicesRequest.newBuilder().build()));
|
||||
assertEquals(expectedResponse, authenticatedServiceStub().getDevices(GetDevicesRequest.newBuilder().build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void removeDevice() {
|
||||
final long deviceId = 17;
|
||||
|
||||
final RemoveDeviceResponse ignored = devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
.setId(deviceId)
|
||||
.build());
|
||||
|
||||
|
@ -179,30 +160,23 @@ class DevicesGrpcServiceTest {
|
|||
|
||||
@Test
|
||||
void removeDevicePrimary() {
|
||||
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
.setId(1)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
.setId(1)
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void removeDeviceNonPrimaryAuthenticated() {
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
|
||||
|
||||
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> devicesStub.removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
.setId(17)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.PERMISSION_DENIED, exception.getStatus().getCode());
|
||||
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.MASTER_ID + 1);
|
||||
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
|
||||
.setId(17)
|
||||
.build()));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(longs = {Device.MASTER_ID, Device.MASTER_ID + 1})
|
||||
void setDeviceName(final long deviceId) {
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
@ -210,7 +184,7 @@ class DevicesGrpcServiceTest {
|
|||
final byte[] deviceName = new byte[128];
|
||||
ThreadLocalRandom.current().nextBytes(deviceName);
|
||||
|
||||
final SetDeviceNameResponse ignored = devicesStub.setDeviceName(SetDeviceNameRequest.newBuilder()
|
||||
final SetDeviceNameResponse ignored = authenticatedServiceStub().setDeviceName(SetDeviceNameRequest.newBuilder()
|
||||
.setName(ByteString.copyFrom(deviceName))
|
||||
.build());
|
||||
|
||||
|
@ -221,11 +195,7 @@ class DevicesGrpcServiceTest {
|
|||
@MethodSource
|
||||
void setDeviceNameIllegalArgument(final SetDeviceNameRequest request) {
|
||||
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(mock(Device.class)));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> devicesStub.setDeviceName(request));
|
||||
|
||||
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setDeviceName(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setDeviceNameIllegalArgument() {
|
||||
|
@ -248,12 +218,12 @@ class DevicesGrpcServiceTest {
|
|||
@Nullable final String expectedApnsVoipToken,
|
||||
@Nullable final String expectedFcmToken) {
|
||||
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
||||
final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
|
||||
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
|
||||
|
||||
verify(device).setApnId(expectedApnsToken);
|
||||
verify(device).setVoipApnId(expectedApnsVoipToken);
|
||||
|
@ -312,7 +282,7 @@ class DevicesGrpcServiceTest {
|
|||
|
||||
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
|
||||
|
||||
final SetPushTokenResponse ignored = devicesStub.setPushToken(request);
|
||||
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
|
||||
|
||||
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
|
||||
}
|
||||
|
@ -352,12 +322,7 @@ class DevicesGrpcServiceTest {
|
|||
void setPushTokenIllegalArgument(final SetPushTokenRequest request) {
|
||||
final Device device = mock(Device.class);
|
||||
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> devicesStub.setPushToken(request));
|
||||
|
||||
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
|
||||
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
|
||||
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
|
||||
}
|
||||
|
||||
|
@ -383,7 +348,7 @@ class DevicesGrpcServiceTest {
|
|||
@Nullable final String fcmToken,
|
||||
@Nullable final String expectedUserAgent) {
|
||||
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
|
@ -393,7 +358,7 @@ class DevicesGrpcServiceTest {
|
|||
when(device.getGcmId()).thenReturn(fcmToken);
|
||||
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
||||
final ClearPushTokenResponse ignored = devicesStub.clearPushToken(ClearPushTokenRequest.newBuilder().build());
|
||||
final ClearPushTokenResponse ignored = authenticatedServiceStub().clearPushToken(ClearPushTokenRequest.newBuilder().build());
|
||||
|
||||
verify(device).setApnId(null);
|
||||
verify(device).setVoipApnId(null);
|
||||
|
@ -430,12 +395,12 @@ class DevicesGrpcServiceTest {
|
|||
@CartesianTest.Values(booleans = {true, false}) final boolean pni,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean paymentActivation) {
|
||||
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(authenticatedAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
||||
final SetCapabilitiesResponse ignored = devicesStub.setCapabilities(SetCapabilitiesRequest.newBuilder()
|
||||
final SetCapabilitiesResponse ignored = authenticatedServiceStub().setCapabilities(SetCapabilitiesRequest.newBuilder()
|
||||
.setStorage(storage)
|
||||
.setTransfer(transfer)
|
||||
.setPni(pni)
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
|
||||
import io.grpc.BindableService;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.time.Duration;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
|
||||
public final class GrpcTestUtils {
|
||||
|
||||
private GrpcTestUtils() {
|
||||
// noop
|
||||
}
|
||||
|
||||
public static MockAuthenticationInterceptor setupAuthenticatedExtension(
|
||||
final GrpcServerExtension extension,
|
||||
final UUID authenticatedAci,
|
||||
final long authenticatedDeviceId,
|
||||
final BindableService service) {
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
|
||||
extension.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(service, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
|
||||
return mockAuthenticationInterceptor;
|
||||
}
|
||||
|
||||
public static void setupUnauthenticatedExtension(
|
||||
final GrpcServerExtension extension,
|
||||
final BindableService service) {
|
||||
extension.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(service, new ErrorMappingInterceptor()));
|
||||
}
|
||||
|
||||
public static void assertStatusException(final Status expected, final Executable serviceCall) {
|
||||
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
|
||||
assertEquals(expected.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
public static void assertRateLimitExceeded(
|
||||
final Duration expectedRetryAfter,
|
||||
final Executable serviceCall,
|
||||
final Object... mocksToCheckForNoInteraction) {
|
||||
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
|
||||
assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY));
|
||||
for (final Object mock: mocksToCheckForNoInteraction) {
|
||||
verifyNoInteractions(mock);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any;
|
|||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
|
@ -20,11 +21,10 @@ import java.util.Collections;
|
|||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.common.EcPreKey;
|
||||
import org.signal.chat.common.EcSignedPreKey;
|
||||
import org.signal.chat.common.KemSignedPreKey;
|
||||
|
@ -48,27 +48,18 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
|
|||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class KeysAnonymousGrpcServiceTest {
|
||||
class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcService, KeysAnonymousGrpc.KeysAnonymousBlockingStub> {
|
||||
|
||||
@Mock
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
@Mock
|
||||
private KeysManager keysManager;
|
||||
|
||||
private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
keysManager = mock(KeysManager.class);
|
||||
|
||||
final KeysAnonymousGrpcService keysGrpcService =
|
||||
new KeysAnonymousGrpcService(accountsManager, keysManager);
|
||||
|
||||
keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
|
||||
@Override
|
||||
protected KeysAnonymousGrpcService createServiceBeforeEachTest() {
|
||||
return new KeysAnonymousGrpcService(accountsManager, keysManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -101,7 +92,7 @@ class KeysAnonymousGrpcServiceTest {
|
|||
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
|
||||
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
|
||||
|
||||
final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
|
@ -152,19 +143,15 @@ class KeysAnonymousGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final StatusRuntimeException statusRuntimeException =
|
||||
assertThrows(StatusRuntimeException.class,
|
||||
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.setDeviceId(Device.MASTER_ID)
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.setDeviceId(Device.MASTER_ID)
|
||||
.build())
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -174,7 +161,7 @@ class KeysAnonymousGrpcServiceTest {
|
|||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class,
|
||||
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
() -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
|
@ -205,19 +192,15 @@ class KeysAnonymousGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class,
|
||||
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.setRequest(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build())
|
||||
.build()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
|
@ -16,9 +15,10 @@ import static org.mockito.Mockito.mock;
|
|||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.time.Duration;
|
||||
|
@ -32,14 +32,13 @@ import java.util.UUID;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.common.EcPreKey;
|
||||
import org.signal.chat.common.EcSignedPreKey;
|
||||
import org.signal.chat.common.KemSignedPreKey;
|
||||
|
@ -56,7 +55,6 @@ import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
|
|||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
|
@ -73,41 +71,34 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
|||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
class KeysGrpcServiceTest {
|
||||
|
||||
private AccountsManager accountsManager;
|
||||
private KeysManager keysManager;
|
||||
private RateLimiter preKeysRateLimiter;
|
||||
|
||||
private Device authenticatedDevice;
|
||||
|
||||
private KeysGrpc.KeysBlockingStub keysStub;
|
||||
|
||||
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
|
||||
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.KeysBlockingStub> {
|
||||
|
||||
private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
|
||||
|
||||
private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
protected static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
keysManager = mock(KeysManager.class);
|
||||
preKeysRateLimiter = mock(RateLimiter.class);
|
||||
@Mock
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
@Mock
|
||||
private KeysManager keysManager;
|
||||
|
||||
@Mock
|
||||
private RateLimiter preKeysRateLimiter;
|
||||
|
||||
@Mock
|
||||
private Device authenticatedDevice;
|
||||
|
||||
|
||||
@Override
|
||||
protected KeysGrpcService createServiceBeforeEachTest() {
|
||||
final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
|
||||
|
||||
when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
|
||||
|
||||
final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
|
||||
keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
|
||||
authenticatedDevice = mock(Device.class);
|
||||
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
final Account authenticatedAccount = mock(Account.class);
|
||||
|
@ -119,17 +110,13 @@ class KeysGrpcServiceTest {
|
|||
when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
|
||||
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
|
||||
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
|
||||
|
||||
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
|
||||
when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
|
||||
when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
|
||||
|
||||
return new KeysGrpcService(accountsManager, keysManager, rateLimiters);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -152,7 +139,7 @@ class KeysGrpcServiceTest {
|
|||
.setPniEcPreKeyCount(3)
|
||||
.setPniKemPreKeyCount(4)
|
||||
.build(),
|
||||
keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
|
||||
authenticatedServiceStub().getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -168,7 +155,7 @@ class KeysGrpcServiceTest {
|
|||
.thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
authenticatedServiceStub().setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.addAllPreKeys(preKeys.stream()
|
||||
.map(preKey -> EcPreKey.newBuilder()
|
||||
|
@ -189,10 +176,7 @@ class KeysGrpcServiceTest {
|
|||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeEcPreKeys(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setOneTimeEcPreKeysWithError() {
|
||||
|
@ -242,7 +226,7 @@ class KeysGrpcServiceTest {
|
|||
.thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setOneTimeKemSignedPreKeys(
|
||||
authenticatedServiceStub().setOneTimeKemSignedPreKeys(
|
||||
SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.addAllPreKeys(preKeys.stream()
|
||||
|
@ -265,10 +249,7 @@ class KeysGrpcServiceTest {
|
|||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setOneTimeKemSignedPreKeys(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setOneTimeKemSignedPreKeysWithError() {
|
||||
|
@ -333,7 +314,7 @@ class KeysGrpcServiceTest {
|
|||
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
|
||||
authenticatedServiceStub().setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(signedPreKey.keyId())
|
||||
|
@ -359,7 +340,7 @@ class KeysGrpcServiceTest {
|
|||
@MethodSource
|
||||
void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request));
|
||||
assertThrows(StatusRuntimeException.class, () -> authenticatedServiceStub().setEcSignedPreKey(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
@ -416,7 +397,7 @@ class KeysGrpcServiceTest {
|
|||
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
authenticatedServiceStub().setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setSignedPreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(lastResortPreKey.keyId())
|
||||
|
@ -437,10 +418,7 @@ class KeysGrpcServiceTest {
|
|||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setKemLastResortPreKey(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setLastResortPreKeyWithError() {
|
||||
|
@ -528,7 +506,7 @@ class KeysGrpcServiceTest {
|
|||
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
|
||||
|
||||
{
|
||||
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(grpcIdentityType)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
|
@ -563,7 +541,7 @@ class KeysGrpcServiceTest {
|
|||
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
{
|
||||
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(grpcIdentityType)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
|
@ -606,15 +584,12 @@ class KeysGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -631,16 +606,13 @@ class KeysGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -655,22 +627,15 @@ class KeysGrpcServiceTest {
|
|||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final Duration retryAfterDuration = Duration.ofMinutes(7);
|
||||
|
||||
when(preKeysRateLimiter.validateReactive(anyString()))
|
||||
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
verifyNoInteractions(accountsManager);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,19 +1,27 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException;
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.SecureRandom;
|
||||
import java.time.Instant;
|
||||
|
@ -24,14 +32,12 @@ import java.util.Optional;
|
|||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import javax.annotation.Nullable;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.common.IdentityType;
|
||||
import org.signal.chat.common.ServiceIdentifier;
|
||||
import org.signal.chat.profile.CredentialType;
|
||||
|
@ -72,43 +78,41 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager;
|
|||
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
|
||||
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
public class ProfileAnonymousGrpcServiceTest {
|
||||
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
|
||||
|
||||
@Mock
|
||||
private Account account;
|
||||
|
||||
@Mock
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
@Mock
|
||||
private ProfilesManager profilesManager;
|
||||
|
||||
@Mock
|
||||
private ProfileBadgeConverter profileBadgeConverter;
|
||||
private ProfileAnonymousGrpc.ProfileAnonymousBlockingStub profileAnonymousBlockingStub;
|
||||
|
||||
@Mock
|
||||
private ServerZkProfileOperations serverZkProfileOperations;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
account = mock(Account.class);
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
profilesManager = mock(ProfilesManager.class);
|
||||
profileBadgeConverter = mock(ProfileBadgeConverter.class);
|
||||
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
|
||||
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
|
||||
profileAnonymousBlockingStub = ProfileAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
|
||||
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
|
||||
final ProfileAnonymousGrpcService profileAnonymousGrpcService = new ProfileAnonymousGrpcService(
|
||||
|
||||
@Override
|
||||
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
|
||||
return new ProfileAnonymousGrpcService(
|
||||
accountsManager,
|
||||
profilesManager,
|
||||
profileBadgeConverter,
|
||||
serverZkProfileOperations
|
||||
);
|
||||
}
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(profileAnonymousGrpcService);
|
||||
@Override
|
||||
protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -151,7 +155,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final GetUnversionedProfileResponse response = profileAnonymousBlockingStub.getUnversionedProfile(request);
|
||||
final GetUnversionedProfileResponse response = unauthenticatedServiceStub().getUnversionedProfile(request);
|
||||
|
||||
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
|
||||
final GetUnversionedProfileResponse expectedResponse = GetUnversionedProfileResponse.newBuilder()
|
||||
|
@ -189,10 +193,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
|
||||
}
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getUnversionedProfile(requestBuilder.build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getUnversionedProfile(requestBuilder.build()));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getUnversionedProfileUnauthenticated() {
|
||||
|
@ -242,7 +243,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final GetVersionedProfileResponse response = profileAnonymousBlockingStub.getVersionedProfile(request);
|
||||
final GetVersionedProfileResponse response = unauthenticatedServiceStub().getVersionedProfile(request);
|
||||
|
||||
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
|
||||
.setName(ByteString.copyFrom(name))
|
||||
|
@ -287,10 +288,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
|
||||
|
||||
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getVersionedProfile(request));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -318,18 +316,15 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
|
||||
}
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getVersionedProfile(requestBuilder.build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getVersionedProfile(requestBuilder.build()));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getVersionedProfileUnauthenticated() {
|
||||
return Stream.of(
|
||||
Arguments.of(true, false),
|
||||
Arguments.of(false, true)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getVersionedProfilePniInvalidArgument() {
|
||||
final byte[] unidentifiedAccessKey = new byte[16];
|
||||
|
@ -346,10 +341,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getVersionedProfile(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getVersionedProfile(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -404,7 +396,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.build();
|
||||
|
||||
final GetExpiringProfileKeyCredentialResponse response = profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request);
|
||||
final GetExpiringProfileKeyCredentialResponse response = unauthenticatedServiceStub().getExpiringProfileKeyCredential(request);
|
||||
|
||||
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
|
||||
|
||||
|
@ -442,10 +434,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey));
|
||||
}
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(requestBuilder.build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(requestBuilder.build()));
|
||||
|
||||
verifyNoInteractions(profilesManager);
|
||||
}
|
||||
|
@ -483,10 +472,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
|
||||
|
||||
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -521,10 +507,7 @@ public class ProfileAnonymousGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileAnonymousBlockingStub.getExpiringProfileKeyCredential(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().getExpiringProfileKeyCredential(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {
|
||||
|
|
|
@ -10,8 +10,6 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatNoException
|
|||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
|
@ -21,15 +19,14 @@ import static org.mockito.Mockito.mock;
|
|||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
|
||||
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
|
||||
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import com.google.i18n.phonenumbers.Phonenumber;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.SecureRandom;
|
||||
import java.time.Clock;
|
||||
|
@ -44,15 +41,14 @@ import java.util.UUID;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.signal.chat.common.IdentityType;
|
||||
import org.signal.chat.common.ServiceIdentifier;
|
||||
import org.signal.chat.profile.CredentialType;
|
||||
|
@ -82,7 +78,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
|
|||
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
|
||||
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
|
||||
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.badges.ProfileBadgeConverter;
|
||||
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
|
||||
|
@ -100,49 +95,54 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
|
|||
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
|
||||
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import reactor.core.publisher.Mono;
|
||||
import software.amazon.awssdk.services.s3.S3AsyncClient;
|
||||
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
|
||||
|
||||
public class ProfileGrpcServiceTest {
|
||||
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcService, ProfileGrpc.ProfileBlockingStub> {
|
||||
|
||||
private static final String S3_BUCKET = "profileBucket";
|
||||
|
||||
private static final String VERSION = "someVersion";
|
||||
|
||||
private static final byte[] VALID_NAME = new byte[81];
|
||||
|
||||
@Mock
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
@Mock
|
||||
private ProfilesManager profilesManager;
|
||||
|
||||
@Mock
|
||||
private DynamicPaymentsConfiguration dynamicPaymentsConfiguration;
|
||||
|
||||
@Mock
|
||||
private S3AsyncClient asyncS3client;
|
||||
|
||||
@Mock
|
||||
private VersionedProfile profile;
|
||||
|
||||
@Mock
|
||||
private Account account;
|
||||
|
||||
@Mock
|
||||
private RateLimiter rateLimiter;
|
||||
|
||||
@Mock
|
||||
private ProfileBadgeConverter profileBadgeConverter;
|
||||
|
||||
@Mock
|
||||
private ServerZkProfileOperations serverZkProfileOperations;
|
||||
private ProfileGrpc.ProfileBlockingStub profileBlockingStub;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
profilesManager = mock(ProfilesManager.class);
|
||||
dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
|
||||
asyncS3client = mock(S3AsyncClient.class);
|
||||
profile = mock(VersionedProfile.class);
|
||||
account = mock(Account.class);
|
||||
rateLimiter = mock(RateLimiter.class);
|
||||
profileBadgeConverter = mock(ProfileBadgeConverter.class);
|
||||
serverZkProfileOperations = mock(ServerZkProfileOperations.class);
|
||||
|
||||
@Override
|
||||
protected ProfileGrpcService createServiceBeforeEachTest() {
|
||||
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
|
||||
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
|
||||
final PolicySigner policySigner = new PolicySigner("accessSecret", "us-west-1");
|
||||
|
@ -170,30 +170,6 @@ public class ProfileGrpcServiceTest {
|
|||
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
|
||||
profileBlockingStub = ProfileGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel())
|
||||
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
|
||||
final ProfileGrpcService profileGrpcService = new ProfileGrpcService(
|
||||
Clock.systemUTC(),
|
||||
accountsManager,
|
||||
profilesManager,
|
||||
dynamicConfigurationManager,
|
||||
badgesConfiguration,
|
||||
asyncS3client,
|
||||
policyGenerator,
|
||||
policySigner,
|
||||
profileBadgeConverter,
|
||||
rateLimiters,
|
||||
serverZkProfileOperations,
|
||||
S3_BUCKET
|
||||
);
|
||||
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(profileGrpcService, mockAuthenticationInterceptor));
|
||||
|
||||
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
|
||||
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());
|
||||
|
||||
|
@ -218,6 +194,21 @@ public class ProfileGrpcServiceTest {
|
|||
when(dynamicPaymentsConfiguration.getDisallowedPrefixes()).thenReturn(Collections.emptyList());
|
||||
|
||||
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
return new ProfileGrpcService(
|
||||
Clock.systemUTC(),
|
||||
accountsManager,
|
||||
profilesManager,
|
||||
dynamicConfigurationManager,
|
||||
badgesConfiguration,
|
||||
asyncS3client,
|
||||
policyGenerator,
|
||||
policySigner,
|
||||
profileBadgeConverter,
|
||||
rateLimiters,
|
||||
serverZkProfileOperations,
|
||||
S3_BUCKET
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -237,7 +228,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setCommitment(ByteString.copyFrom(commitment))
|
||||
.build();
|
||||
|
||||
profileBlockingStub.setProfile(request);
|
||||
authenticatedServiceStub().setProfile(request);
|
||||
|
||||
final ArgumentCaptor<VersionedProfile> profileArgumentCaptor = ArgumentCaptor.forClass(VersionedProfile.class);
|
||||
|
||||
|
@ -274,7 +265,7 @@ public class ProfileGrpcServiceTest {
|
|||
hasPreviousProfile ? Optional.of(profile) : Optional.empty()));
|
||||
when(profilesManager.setAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
SetProfileResponse response = profileBlockingStub.setProfile(request);
|
||||
final SetProfileResponse response = authenticatedServiceStub().setProfile(request);
|
||||
|
||||
if (expectHasS3UploadPath) {
|
||||
assertTrue(response.getAttributes().getPath().startsWith("profiles/"));
|
||||
|
@ -312,10 +303,7 @@ public class ProfileGrpcServiceTest {
|
|||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setProfileInvalidRequestData(final SetProfileRequest request) {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setProfile(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setProfileInvalidRequestData() throws InvalidInputException{
|
||||
|
@ -386,12 +374,10 @@ public class ProfileGrpcServiceTest {
|
|||
when(profilesManager.getAsync(any(), anyString())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
|
||||
|
||||
if (hasExistingPaymentAddress) {
|
||||
assertDoesNotThrow(() -> profileBlockingStub.setProfile(request),
|
||||
assertDoesNotThrow(() -> authenticatedServiceStub().setProfile(request),
|
||||
"Payment address changes in disallowed countries should still be allowed if the account already has a valid payment address");
|
||||
} else {
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.setProfile(request));
|
||||
assertEquals(Status.PERMISSION_DENIED.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().setProfile(request));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -433,7 +419,7 @@ public class ProfileGrpcServiceTest {
|
|||
when(profileBadgeConverter.convert(any(), any(), anyBoolean())).thenReturn(badges);
|
||||
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
|
||||
|
||||
final GetUnversionedProfileResponse response = profileBlockingStub.getUnversionedProfile(request);
|
||||
final GetUnversionedProfileResponse response = authenticatedServiceStub().getUnversionedProfile(request);
|
||||
|
||||
final byte[] unidentifiedAccessChecksum = UnidentifiedAccessChecksum.generateFor(unidentifiedAccessKey);
|
||||
final GetUnversionedProfileResponse prototypeExpectedResponse = GetUnversionedProfileResponse.newBuilder()
|
||||
|
@ -472,10 +458,7 @@ public class ProfileGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getUnversionedProfile(request));
|
||||
|
||||
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getUnversionedProfile(request));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -493,14 +476,7 @@ public class ProfileGrpcServiceTest {
|
|||
.build())
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> profileBlockingStub.getUnversionedProfile(request));
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
|
||||
verifyNoInteractions(accountsManager);
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getUnversionedProfile(request), accountsManager);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -531,7 +507,7 @@ public class ProfileGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
|
||||
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
|
||||
|
||||
final GetVersionedProfileResponse response = profileBlockingStub.getVersionedProfile(request);
|
||||
final GetVersionedProfileResponse response = authenticatedServiceStub().getVersionedProfile(request);
|
||||
|
||||
final GetVersionedProfileResponse.Builder expectedResponseBuilder = GetVersionedProfileResponse.newBuilder()
|
||||
.setName(ByteString.copyFrom(name))
|
||||
|
@ -545,7 +521,6 @@ public class ProfileGrpcServiceTest {
|
|||
|
||||
assertEquals(expectedResponseBuilder.build(), response);
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getVersionedProfile() {
|
||||
return Stream.of(
|
||||
Arguments.of("version1", "version1", true),
|
||||
|
@ -553,6 +528,7 @@ public class ProfileGrpcServiceTest {
|
|||
Arguments.of("version1", "version2", false)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getVersionedProfileAccountOrProfileNotFound(final boolean missingAccount, final boolean missingProfile) {
|
||||
|
@ -566,10 +542,7 @@ public class ProfileGrpcServiceTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(missingAccount ? Optional.empty() : Optional.of(account)));
|
||||
when(profilesManager.getAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(missingProfile ? Optional.empty() : Optional.of(profile)));
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getVersionedProfile(request));
|
||||
|
||||
assertEquals(Status.NOT_FOUND.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getVersionedProfile(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getVersionedProfileAccountOrProfileNotFound() {
|
||||
|
@ -581,10 +554,7 @@ public class ProfileGrpcServiceTest {
|
|||
|
||||
@Test
|
||||
void getVersionedProfileRatelimited() {
|
||||
final Duration retryAfterDuration = Duration.ofMinutes(7);
|
||||
|
||||
when(rateLimiter.validateReactive(any(UUID.class)))
|
||||
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
|
||||
final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(7), false);
|
||||
|
||||
final GetVersionedProfileRequest request = GetVersionedProfileRequest.newBuilder()
|
||||
.setAccountIdentifier(ServiceIdentifier.newBuilder()
|
||||
|
@ -594,15 +564,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getVersionedProfile(request));
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
|
||||
verifyNoInteractions(accountsManager);
|
||||
verifyNoInteractions(profilesManager);
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getVersionedProfile(request), accountsManager, profilesManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -615,9 +577,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getVersionedProfile(request));
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getVersionedProfile(request));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -664,7 +624,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
final GetExpiringProfileKeyCredentialResponse response = profileBlockingStub.getExpiringProfileKeyCredential(request);
|
||||
final GetExpiringProfileKeyCredentialResponse response = authenticatedServiceStub().getExpiringProfileKeyCredential(request);
|
||||
|
||||
assertArrayEquals(credentialResponse.serialize(), response.getProfileKeyCredential().toByteArray());
|
||||
|
||||
|
@ -677,9 +637,8 @@ public class ProfileGrpcServiceTest {
|
|||
|
||||
@Test
|
||||
void getExpiringProfileKeyCredentialRateLimited() {
|
||||
final Duration retryAfterDuration = Duration.ofMinutes(5);
|
||||
when(rateLimiter.validateReactive(AUTHENTICATED_ACI))
|
||||
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
|
||||
final Duration retryAfterDuration = MockUtils.updateRateLimiterResponseToFail(
|
||||
rateLimiter, AUTHENTICATED_ACI, Duration.ofMinutes(5), false);
|
||||
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
|
||||
|
||||
final GetExpiringProfileKeyCredentialRequest request = GetExpiringProfileKeyCredentialRequest.newBuilder()
|
||||
|
@ -692,14 +651,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
|
||||
verifyNoInteractions(profilesManager);
|
||||
assertRateLimitExceeded(retryAfterDuration, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request), profilesManager);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -723,10 +675,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
final StatusRuntimeException statusRuntimeException = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, statusRuntimeException.getStatus().getCode());
|
||||
assertStatusException(Status.NOT_FOUND, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getExpiringProfileKeyCredentialAccountOrProfileNotFound() {
|
||||
|
@ -761,10 +710,7 @@ public class ProfileGrpcServiceTest {
|
|||
.setVersion("someVersion")
|
||||
.build();
|
||||
|
||||
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> profileBlockingStub.getExpiringProfileKeyCredential(request));
|
||||
|
||||
assertEquals(Status.Code.INVALID_ARGUMENT, exception.getStatus().getCode());
|
||||
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().getExpiringProfileKeyCredential(request));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> getExpiringProfileKeyCredentialInvalidArgument() {
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import io.grpc.BindableService;
|
||||
import io.grpc.Channel;
|
||||
import io.grpc.stub.AbstractBlockingStub;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
/**
|
||||
* Base class for the common case of gRPC services tests. This base class makes some assumptions
|
||||
* and introduces some constraints on the implementing classes with a goal of simplifying the process
|
||||
* of creating a test for the most of the gRPC services.
|
||||
* <ul>
|
||||
* <li>
|
||||
* Test classes extending this class will have to override the {@link #createServiceBeforeEachTest()} method
|
||||
* with the logic that creates an instance of the service to test. This method is called before each test and should
|
||||
* contain other setup code that would normally go into {@code @BeforeEach} method.
|
||||
* </li>
|
||||
* <li>
|
||||
* This base class takes care of creating two service stubs: {@code authenticatedServiceStub} and {@code unauthenticatedServiceStub}.
|
||||
* Normally, those stubs are created by the call to the {@code newBlockingStub()} method on the {@code *Stub} class, e.g.:
|
||||
* <pre>CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());</pre>
|
||||
* In this class, those stubs are created by the {@link #createStub(Channel)} method that has a default implementation that is based on
|
||||
* figuring out the name of the {@code `*Grpc`} class and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
|
||||
* </li>
|
||||
* <li>
|
||||
* This class takes care of initializing {@code Mockito} annotations processing, so implementing classes
|
||||
* can annotate their fields with {@code @Mock} and have those mocks ready by the time {@link #createServiceBeforeEachTest()} is called.
|
||||
* </li>
|
||||
* </ul>
|
||||
* @param <SERVICE> Class of the gRPC service that is being tested.
|
||||
* @param <STUB> Class of the gRPC service stub.
|
||||
*/
|
||||
public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB extends AbstractBlockingStub<STUB>> {
|
||||
|
||||
@RegisterExtension
|
||||
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_AUTHENTICATED = new GrpcServerExtension();
|
||||
|
||||
@RegisterExtension
|
||||
protected static final GrpcServerExtension GRPC_SERVER_EXTENSION_UNAUTHENTICATED = new GrpcServerExtension();
|
||||
|
||||
protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
|
||||
protected static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
|
||||
private AutoCloseable mocksCloseable;
|
||||
|
||||
private MockAuthenticationInterceptor mockAuthenticationInterceptor;
|
||||
|
||||
private SERVICE service;
|
||||
|
||||
private STUB authenticatedServiceStub;
|
||||
|
||||
private STUB unauthenticatedServiceStub;
|
||||
|
||||
|
||||
/**
|
||||
* This method is invoked before each test and is expected to create an instance of the gRPC service
|
||||
* that is being tested and also to perform all necessary before-each setup.
|
||||
* </p>
|
||||
* Extending classes may have their own {@code @BeforeEach} method, but it will be called after this method.
|
||||
* @return an instance of the gRPC service.
|
||||
*/
|
||||
protected abstract SERVICE createServiceBeforeEachTest();
|
||||
|
||||
/**
|
||||
* The default implementation of this method is based on figuring out the name of the {@code `*Grpc`} class
|
||||
* and invoking {@code `*Grpc.newBlockingStub()`} method with reflection.
|
||||
* <p>
|
||||
* Overriding this method can be helpful if addutional configuration of the stub is required, e.g. adding interceptors:
|
||||
* <pre>
|
||||
* protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
|
||||
* final Metadata metadata = new Metadata();
|
||||
* metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
|
||||
* metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
|
||||
* return super.createStub(channel)
|
||||
* .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
* }
|
||||
* </pre>
|
||||
* @param channel grpc channel to create create the stub for.
|
||||
* @return and instance of the service stub.
|
||||
*/
|
||||
protected STUB createStub(final Channel channel) throws
|
||||
ClassNotFoundException,
|
||||
NoSuchMethodException,
|
||||
InvocationTargetException,
|
||||
IllegalAccessException {
|
||||
final String serviceClassName = service.bindService().getServiceDescriptor().getName();
|
||||
final String grpcClassName = serviceClassName + "Grpc";
|
||||
final Class<?> grpcClass = ClassLoader.getSystemClassLoader().loadClass(grpcClassName);
|
||||
final Method newBlockingStubMethod = grpcClass.getMethod("newBlockingStub", Channel.class);
|
||||
final Object stub = newBlockingStubMethod.invoke(null, channel);
|
||||
//noinspection unchecked
|
||||
return (STUB) stub;
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
protected void baseSetup() {
|
||||
mocksCloseable = MockitoAnnotations.openMocks(this);
|
||||
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
|
||||
mockAuthenticationInterceptor = GrpcTestUtils.setupAuthenticatedExtension(
|
||||
GRPC_SERVER_EXTENSION_AUTHENTICATED, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
|
||||
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, service);
|
||||
try {
|
||||
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
|
||||
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Could not create a stub based on the service name. Try overriding `createStub()` method.");
|
||||
}
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void releaseMocks() throws Exception {
|
||||
mocksCloseable.close();
|
||||
}
|
||||
|
||||
public MockAuthenticationInterceptor mockAuthenticationInterceptor() {
|
||||
return mockAuthenticationInterceptor;
|
||||
}
|
||||
|
||||
protected SERVICE service() {
|
||||
return service;
|
||||
}
|
||||
|
||||
protected STUB authenticatedServiceStub() {
|
||||
return authenticatedServiceStub;
|
||||
}
|
||||
|
||||
protected STUB unauthenticatedServiceStub() {
|
||||
return unauthenticatedServiceStub;
|
||||
}
|
||||
}
|
|
@ -111,7 +111,7 @@ public class RateLimitersTest {
|
|||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
|
||||
assertEquals(expected, limiter.config());
|
||||
assertEquals(expected, config(limiter));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -131,16 +131,16 @@ public class RateLimitersTest {
|
|||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
|
||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
|
||||
assertEquals(initialRateLimiterConfig, limiter.config());
|
||||
assertEquals(initialRateLimiterConfig, config(limiter));
|
||||
|
||||
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
|
||||
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
|
||||
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
|
||||
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
|
||||
|
||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
|
||||
assertEquals(updatedRateLimiterCongig, limiter.config());
|
||||
assertEquals(updatedRateLimiterCongig, config(limiter));
|
||||
|
||||
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
|
||||
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
|
||||
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeAttemptLimiter()));
|
||||
assertEquals(baseConfig, config(rateLimiters.getRecaptchaChallengeSuccessLimiter()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -161,22 +161,22 @@ public class RateLimitersTest {
|
|||
// test only default is present
|
||||
mapForDynamic.remove(descriptor.id());
|
||||
mapForStatic.remove(descriptor.id());
|
||||
assertEquals(defaultConfig, limiter.config());
|
||||
assertEquals(defaultConfig, config(limiter));
|
||||
|
||||
// test dynamic and no static
|
||||
mapForDynamic.put(descriptor.id(), configForDynamic);
|
||||
mapForStatic.remove(descriptor.id());
|
||||
assertEquals(configForDynamic, limiter.config());
|
||||
assertEquals(configForDynamic, config(limiter));
|
||||
|
||||
// test dynamic and static
|
||||
mapForDynamic.put(descriptor.id(), configForDynamic);
|
||||
mapForStatic.put(descriptor.id(), configForStatic);
|
||||
assertEquals(configForDynamic, limiter.config());
|
||||
assertEquals(configForDynamic, config(limiter));
|
||||
|
||||
// test static, but no dynamic
|
||||
mapForDynamic.remove(descriptor.id());
|
||||
mapForStatic.put(descriptor.id(), configForStatic);
|
||||
assertEquals(configForStatic, limiter.config());
|
||||
assertEquals(configForStatic, config(limiter));
|
||||
}
|
||||
|
||||
private record TestDescriptor(String id) implements RateLimiterDescriptor {
|
||||
|
@ -191,4 +191,14 @@ public class RateLimitersTest {
|
|||
return new RateLimiterConfig(1, Duration.ofMinutes(1));
|
||||
}
|
||||
}
|
||||
|
||||
private static RateLimiterConfig config(final RateLimiter rateLimiter) {
|
||||
if (rateLimiter instanceof StaticRateLimiter rm) {
|
||||
return rm.config();
|
||||
}
|
||||
if (rateLimiter instanceof DynamicRateLimiter rm) {
|
||||
return rm.config();
|
||||
}
|
||||
throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,12 +12,14 @@ import static org.mockito.Mockito.doThrow;
|
|||
|
||||
import java.time.Duration;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.apache.commons.lang3.RandomUtils;
|
||||
import org.mockito.Mockito;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
public final class MockUtils {
|
||||
|
||||
|
@ -46,32 +48,80 @@ public final class MockUtils {
|
|||
}
|
||||
|
||||
public static void updateRateLimiterResponseToAllow(
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.For handle,
|
||||
final RateLimiter mockRateLimiter,
|
||||
final String input) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
|
||||
try {
|
||||
doNothing().when(mockRateLimiter).validate(eq(input));
|
||||
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
|
||||
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void updateRateLimiterResponseToAllow(
|
||||
final RateLimiter mockRateLimiter,
|
||||
final UUID input) {
|
||||
try {
|
||||
doNothing().when(mockRateLimiter).validate(eq(input));
|
||||
doReturn(CompletableFuture.completedFuture(null)).when(mockRateLimiter).validateAsync(eq(input));
|
||||
doReturn(Mono.fromFuture(CompletableFuture.completedFuture(null))).when(mockRateLimiter).validateReactive(eq(input));
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void updateRateLimiterResponseToAllow(
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.For handle,
|
||||
final String input) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
|
||||
updateRateLimiterResponseToAllow(mockRateLimiter, input);
|
||||
}
|
||||
|
||||
public static void updateRateLimiterResponseToAllow(
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.For handle,
|
||||
final UUID input) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
|
||||
updateRateLimiterResponseToAllow(mockRateLimiter, input);
|
||||
}
|
||||
|
||||
public static Duration updateRateLimiterResponseToFail(
|
||||
final RateLimiter mockRateLimiter,
|
||||
final String input,
|
||||
final Duration retryAfter,
|
||||
final boolean legacyStatusCode) {
|
||||
try {
|
||||
doNothing().when(mockRateLimiter).validate(eq(input));
|
||||
final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
|
||||
doThrow(exception).when(mockRateLimiter).validate(eq(input));
|
||||
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
|
||||
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
|
||||
return retryAfter;
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void updateRateLimiterResponseToFail(
|
||||
public static Duration updateRateLimiterResponseToFail(
|
||||
final RateLimiter mockRateLimiter,
|
||||
final UUID input,
|
||||
final Duration retryAfter,
|
||||
final boolean legacyStatusCode) {
|
||||
try {
|
||||
final RateLimitExceededException exception = new RateLimitExceededException(retryAfter, legacyStatusCode);
|
||||
doThrow(exception).when(mockRateLimiter).validate(eq(input));
|
||||
doReturn(CompletableFuture.failedFuture(exception)).when(mockRateLimiter).validateAsync(eq(input));
|
||||
doReturn(Mono.fromFuture(CompletableFuture.failedFuture(exception))).when(mockRateLimiter).validateReactive(eq(input));
|
||||
return retryAfter;
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static Duration updateRateLimiterResponseToFail(
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.For handle,
|
||||
final String input,
|
||||
|
@ -79,14 +129,10 @@ public final class MockUtils {
|
|||
final boolean legacyStatusCode) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
|
||||
try {
|
||||
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
|
||||
}
|
||||
|
||||
public static void updateRateLimiterResponseToFail(
|
||||
public static Duration updateRateLimiterResponseToFail(
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.For handle,
|
||||
final UUID input,
|
||||
|
@ -94,11 +140,7 @@ public final class MockUtils {
|
|||
final boolean legacyStatusCode) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
|
||||
try {
|
||||
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return updateRateLimiterResponseToFail(mockRateLimiter, input, retryAfter, legacyStatusCode);
|
||||
}
|
||||
|
||||
public static SecretBytes randomSecretBytes(final int size) {
|
||||
|
|
Loading…
Reference in New Issue