diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 5e50f91fe..2dc69451e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -739,7 +739,7 @@ public class WhisperServerService extends Application 0 + final String username = shouldDeriveUsername() ? hmac256TruncatedToHexString(userDerivationKey, identity, TRUNCATE_LENGTH) : identity; - final long currentTimeSeconds = TimeUnit.MILLISECONDS.toSeconds(clock.millis()); + final long currentTimeSeconds = currentTimeSeconds(); - final String dataToSign = username + ":" + currentTimeSeconds; + final String dataToSign = username + DELIMITER + currentTimeSeconds; final String signature = truncateSignature ? hmac256TruncatedToHexString(key, dataToSign, TRUNCATE_LENGTH) : hmac256ToHexString(key, dataToSign); - final String token = (prependUsername ? dataToSign : currentTimeSeconds) + ":" + signature; + final String token = (prependUsername ? dataToSign : currentTimeSeconds) + DELIMITER + signature; return new ExternalServiceCredentials(username, token); } + /** + * In certain cases, identity (as it was passed to `generateFor` method) + * is a part of the signature (`password`, in terms of `ExternalServiceCredentials`) string itself. + * For such cases, this method returns the value of the identity string. + * @param password `password` part of `ExternalServiceCredentials` + * @return non-empty optional with an identity string value, or empty if value can't be extracted. + */ + public Optional identityFromSignature(final String password) { + // for some generators, identity in the clear is just not a part of the password + if (!prependUsername || shouldDeriveUsername() || StringUtils.isBlank(password)) { + return Optional.empty(); + } + // checking for the case of unexpected format + return StringUtils.countMatches(password, DELIMITER) == 2 + ? Optional.of(password.substring(0, password.indexOf(DELIMITER))) + : Optional.empty(); + } + + /** + * Given an instance of {@link ExternalServiceCredentials} object, checks that the password + * matches the username taking into accound this generator's configuration. + * @param credentials an instance of {@link ExternalServiceCredentials} + * @return An optional with a timestamp (seconds) of when the credentials were generated, + * or an empty optional if the password doesn't match the username for any reason (including malformed data) + */ + public Optional validateAndGetTimestamp(final ExternalServiceCredentials credentials) { + final String[] parts = requireNonNull(credentials).password().split(DELIMITER); + final String timestampSeconds; + final String actualSignature; + + // making sure password format matches our expectations based on the generator configuration + if (parts.length == 3 && prependUsername) { + final String username = parts[0]; + // username has to match the one from `credentials` + if (!credentials.username().equals(username)) { + return Optional.empty(); + } + timestampSeconds = parts[1]; + actualSignature = parts[2]; + } else if (parts.length == 2 && !prependUsername) { + timestampSeconds = parts[0]; + actualSignature = parts[1]; + } else { + // unexpected password format + return Optional.empty(); + } + + final String signedData = credentials.username() + DELIMITER + timestampSeconds; + final String expectedSignature = truncateSignature + ? hmac256TruncatedToHexString(key, signedData, TRUNCATE_LENGTH) + : hmac256ToHexString(key, signedData); + + // if the signature is valid it's safe to parse the `timestampSeconds` string into Long + return hmacHexStringsEqual(expectedSignature, actualSignature) + ? Optional.of(Long.valueOf(timestampSeconds)) + : Optional.empty(); + } + + /** + * Given an instance of {@link ExternalServiceCredentials} object and the max allowed age for those credentials, + * checks if credentials are valid and not expired. + * @param credentials an instance of {@link ExternalServiceCredentials} + * @param maxAgeSeconds age in seconds + * @return An optional with a timestamp (seconds) of when the credentials were generated, + * or an empty optional if the password doesn't match the username for any reason (including malformed data) + */ + public Optional validateAndGetTimestamp(final ExternalServiceCredentials credentials, final long maxAgeSeconds) { + return validateAndGetTimestamp(credentials) + .filter(ts -> currentTimeSeconds() - ts <= maxAgeSeconds); + } + + private boolean shouldDeriveUsername() { + return userDerivationKey.length > 0; + } + + private long currentTimeSeconds() { + return clock.instant().getEpochSecond(); + } + public static class Builder { private final byte[] key; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java index a8a4bcfac..f98701542 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.configuration; @@ -74,6 +74,9 @@ public class RateLimitsConfiguration { @JsonProperty private RateLimitConfiguration stories = new RateLimitConfiguration(10_000, 10_000 / (24.0 * 60.0)); + @JsonProperty + private RateLimitConfiguration backupAuthCheck = new RateLimitConfiguration(100, 100 / (24.0 * 60.0)); + public RateLimitConfiguration getAutoBlock() { return autoBlock; } @@ -158,7 +161,13 @@ public class RateLimitsConfiguration { return checkAccountExistence; } - public RateLimitConfiguration getStories() { return stories; } + public RateLimitConfiguration getStories() { + return stories; + } + + public RateLimitConfiguration getBackupAuthCheck() { + return backupAuthCheck; + } public static class RateLimitConfiguration { @JsonProperty diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index d9170fcf3..8e72c4d70 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -87,6 +87,7 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameResponse; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.UsernameRequest; import org.whispersystems.textsecuregcm.entities.UsernameResponse; +import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -772,9 +773,9 @@ public class AccountController { @GET @Path("/username/{username}") @Produces(MediaType.APPLICATION_JSON) + @RateLimitedByIp(RateLimiters.Handle.USERNAME_LOOKUP) public AccountIdentifierResponse lookupUsername( @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent, - @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor, @PathParam("username") final String username, @Context final HttpServletRequest request) throws RateLimitExceededException { @@ -783,8 +784,6 @@ public class AccountController { throw new BadRequestException(); } - rateLimitByClientIp(rateLimiters.getUsernameLookupLimiter(), forwardedFor); - checkUsername(username, userAgent); return accounts @@ -796,8 +795,8 @@ public class AccountController { @HEAD @Path("/account/{uuid}") + @RateLimitedByIp(RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE) public Response accountExists( - @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor, @PathParam("uuid") final UUID uuid, @Context HttpServletRequest request) throws RateLimitExceededException { @@ -805,7 +804,6 @@ public class AccountController { if (StringUtils.isNotBlank(request.getHeader("Authorization"))) { throw new BadRequestException(); } - rateLimitByClientIp(rateLimiters.getCheckAccountExistenceLimiter(), forwardedFor); final Status status = accounts.getByAccountIdentifier(uuid) .or(() -> accounts.getByPhoneNumberIdentifier(uuid)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java index a284576f9..863f20ab7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java @@ -5,40 +5,151 @@ package org.whispersystems.textsecuregcm.controllers; +import static java.util.Objects.requireNonNull; + import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; +import java.time.Clock; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import javax.validation.Valid; +import javax.validation.constraints.NotNull; +import javax.ws.rs.Consumes; import javax.ws.rs.GET; +import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; import org.apache.commons.codec.DecoderException; +import org.apache.commons.lang3.tuple.Pair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration; +import org.whispersystems.textsecuregcm.entities.AuthCheckRequest; +import org.whispersystems.textsecuregcm.entities.AuthCheckResponse; +import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.util.UUIDUtil; @Path("/v1/backup") public class SecureBackupController { - private final ExternalServiceCredentialsGenerator backupServiceCredentialsGenerator; + private static final long MAX_AGE_SECONDS = TimeUnit.DAYS.toSeconds(30); - public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureBackupServiceConfiguration cfg) - throws DecoderException { - return ExternalServiceCredentialsGenerator - .builder(cfg.getUserAuthenticationTokenSharedSecret()) - .prependUsername(true) - .build(); + private final ExternalServiceCredentialsGenerator credentialsGenerator; + + private final AccountsManager accountsManager; + + public static ExternalServiceCredentialsGenerator credentialsGenerator(final SecureBackupServiceConfiguration cfg) { + return credentialsGenerator(cfg, Clock.systemUTC()); } - public SecureBackupController(ExternalServiceCredentialsGenerator backupServiceCredentialsGenerator) { - this.backupServiceCredentialsGenerator = backupServiceCredentialsGenerator; + public static ExternalServiceCredentialsGenerator credentialsGenerator( + final SecureBackupServiceConfiguration cfg, + final Clock clock) { + try { + return ExternalServiceCredentialsGenerator + .builder(cfg.getUserAuthenticationTokenSharedSecret()) + .prependUsername(true) + .withClock(clock) + .build(); + } catch (final DecoderException e) { + throw new IllegalStateException(e); + } + } + + public SecureBackupController( + final ExternalServiceCredentialsGenerator credentialsGenerator, + final AccountsManager accountsManager) { + this.credentialsGenerator = requireNonNull(credentialsGenerator); + this.accountsManager = requireNonNull(accountsManager); } @Timed @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) { - return backupServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid()); + public ExternalServiceCredentials getAuth(final @Auth AuthenticatedAccount auth) { + return credentialsGenerator.generateForUuid(auth.getAccount().getUuid()); + } + + @Timed + @POST + @Path("/auth/check") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + @RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK) + public AuthCheckResponse authCheck(@NotNull @Valid final AuthCheckRequest request) { + final Map results = new HashMap<>(); + final Map> tokenToUuid = new HashMap<>(); + final Map uuidToLatestTimestamp = new HashMap<>(); + + // first pass -- filter out all tokens that contain invalid credentials + // (this could be either legit but expired or illegitimate for any reason) + request.passwords().forEach(token -> { + // each token is supposed to be in a "${username}:${password}" form, + // (note that password part may also contain ':' characters) + final String[] parts = token.split(":", 2); + if (parts.length != 2) { + results.put(token, AuthCheckResponse.Result.INVALID); + return; + } + final ExternalServiceCredentials credentials = new ExternalServiceCredentials(parts[0], parts[1]); + final Optional maybeTimestamp = credentialsGenerator.validateAndGetTimestamp(credentials, MAX_AGE_SECONDS); + final Optional maybeUuid = UUIDUtil.fromStringSafe(credentials.username()); + if (maybeTimestamp.isEmpty() || maybeUuid.isEmpty()) { + results.put(token, AuthCheckResponse.Result.INVALID); + return; + } + // now that we validated signature and token age, we will also find the latest of the tokens + // for each username + final Long timestamp = maybeTimestamp.get(); + final UUID uuid = maybeUuid.get(); + tokenToUuid.put(token, Pair.of(uuid, timestamp)); + final Long latestTimestamp = uuidToLatestTimestamp.getOrDefault(uuid, 0L); + if (timestamp > latestTimestamp) { + uuidToLatestTimestamp.put(uuid, timestamp); + } + }); + + // as a result of the first pass we now have some tokens that are marked invalid, + // and for others we now know if for any username the list contains multiple tokens + // we also know all distinct usernames from the list + + // if it so happens that all tokens are invalid -- respond right away + if (tokenToUuid.isEmpty()) { + return new AuthCheckResponse(results); + } + + final Predicate uuidMatches = accountsManager + .getByE164(request.number()) + .map(account -> (Predicate) candidateUuid -> account.getUuid().equals(candidateUuid)) + .orElse(candidateUuid -> false); + + // second pass will let us discard tokens that have newer versions and will also let us pick the winner (if any) + request.passwords().forEach(token -> { + if (results.containsKey(token)) { + // result already calculated + return; + } + final Pair uuidAndTime = requireNonNull(tokenToUuid.get(token)); + final Long latestTimestamp = requireNonNull(uuidToLatestTimestamp.get(uuidAndTime.getLeft())); + // check if a newer version available + if (uuidAndTime.getRight() < latestTimestamp) { + results.put(token, AuthCheckResponse.Result.INVALID); + return; + } + results.put(token, uuidMatches.test(uuidAndTime.getLeft()) + ? AuthCheckResponse.Result.MATCH + : AuthCheckResponse.Result.NO_MATCH); + }); + + return new AuthCheckResponse(results); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckRequest.java new file mode 100644 index 000000000..e8a4a3500 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckRequest.java @@ -0,0 +1,16 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import java.util.List; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Size; +import org.whispersystems.textsecuregcm.util.E164; + +public record AuthCheckRequest(@NotNull @E164 String number, + @NotEmpty @Size(max = 10) List passwords) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckResponse.java new file mode 100644 index 000000000..a9645b0ea --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AuthCheckResponse.java @@ -0,0 +1,30 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonValue; +import java.util.Map; +import javax.validation.constraints.NotNull; + +public record AuthCheckResponse(@NotNull Map matches) { + + public enum Result { + MATCH("match"), + NO_MATCH("no-match"), + INVALID("invalid"); + + private final String clientCode; + + Result(final String clientCode) { + this.clientCode = clientCode; + } + + @JsonValue + public String clientCode() { + return clientCode; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java new file mode 100644 index 000000000..3147073b7 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HttpHeaders; +import java.io.IOException; +import java.time.Duration; +import java.util.Optional; +import javax.ws.rs.ClientErrorException; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.container.ContainerRequestFilter; +import javax.ws.rs.core.Response; +import javax.ws.rs.ext.ExceptionMapper; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.util.HeaderUtils; + +public class RateLimitByIpFilter implements ContainerRequestFilter { + + private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class); + + @VisibleForTesting + static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1)); + + private static final ExceptionMapper EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper(); + + private final RateLimiters rateLimiters; + + + public RateLimitByIpFilter(final RateLimiters rateLimiters) { + this.rateLimiters = requireNonNull(rateLimiters); + } + + @Override + public void filter(final ContainerRequestContext requestContext) throws IOException { + // requestContext.getUriInfo() should always be an instance of `ExtendedUriInfo` + // in the Jersey client + if (!(requestContext.getUriInfo() instanceof final ExtendedUriInfo uriInfo)) { + return; + } + + final RateLimitedByIp annotation = uriInfo.getMatchedResourceMethod() + .getInvocable() + .getHandlingMethod() + .getAnnotation(RateLimitedByIp.class); + + if (annotation == null) { + return; + } + + final RateLimiters.Handle handle = annotation.value(); + + try { + final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR); + final Optional maybeMostRecentProxy = Optional.ofNullable(xffHeader) + .flatMap(HeaderUtils::getMostRecentProxy); + + // checking if we failed to extract the most recent IP from the X-Forwarded-For header + // for any reason + if (maybeMostRecentProxy.isEmpty()) { + // checking if annotation is configured to fail when the most recent IP is not resolved + if (annotation.failOnUnresolvedIp()) { + logger.error("Missing/bad X-Forwarded-For: {}", xffHeader); + throw INVALID_HEADER_EXCEPTION; + } + // otherwise, allow request + return; + } + + final Optional maybeRateLimiter = rateLimiters.byHandle(handle); + if (maybeRateLimiter.isEmpty()) { + logger.warn("RateLimiter not found for {}. Make sure it's initialized in RateLimiters class", handle); + return; + } + + maybeRateLimiter.get().validate(maybeMostRecentProxy.get()); + } catch (RateLimitExceededException e) { + final Response response = EXCEPTION_MAPPER.toResponse(e); + throw new ClientErrorException(response); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIp.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIp.java new file mode 100644 index 000000000..29da64979 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIp.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface RateLimitedByIp { + + RateLimiters.Handle value(); + + boolean failOnUnresolvedIp() default true; +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 4485bd205..c79209f8e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -1,15 +1,41 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.limits; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; public class RateLimiters { + public enum Handle { + USERNAME_LOOKUP("usernameLookup"), + CHECK_ACCOUNT_EXISTENCE("checkAccountExistence"), + BACKUP_AUTH_CHECK; + + private final String id; + + + Handle(final String id) { + this.id = id; + } + + Handle() { + this.id = name(); + } + + public String id() { + return id; + } + } + private final RateLimiter smsDestinationLimiter; private final RateLimiter voiceDestinationLimiter; private final RateLimiter voiceDestinationDailyLimiter; @@ -17,114 +43,51 @@ public class RateLimiters { private final RateLimiter smsVoicePrefixLimiter; private final RateLimiter verifyLimiter; private final RateLimiter pinLimiter; - private final RateLimiter attachmentLimiter; private final RateLimiter preKeysLimiter; private final RateLimiter messagesLimiter; - private final RateLimiter allocateDeviceLimiter; private final RateLimiter verifyDeviceLimiter; - private final RateLimiter turnLimiter; - private final RateLimiter profileLimiter; private final RateLimiter stickerPackLimiter; - private final RateLimiter artPackLimiter; - private final RateLimiter usernameLookupLimiter; private final RateLimiter usernameSetLimiter; - private final RateLimiter usernameReserveLimiter; - - private final RateLimiter checkAccountExistenceLimiter; - private final RateLimiter storiesLimiter; - public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) { - this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", - config.getSmsDestination().getBucketSize(), - config.getSmsDestination().getLeakRatePerMinute()); + private final Map rateLimiterByHandle; - this.voiceDestinationLimiter = new RateLimiter(cacheCluster, "voxDestination", - config.getVoiceDestination().getBucketSize(), - config.getVoiceDestination().getLeakRatePerMinute()); + public RateLimiters(final RateLimitsConfiguration config, final FaultTolerantRedisCluster cacheCluster) { + this.smsDestinationLimiter = fromConfig("smsDestination", config.getSmsDestination(), cacheCluster); + this.voiceDestinationLimiter = fromConfig("voxDestination", config.getVoiceDestination(), cacheCluster); + this.voiceDestinationDailyLimiter = fromConfig("voxDestinationDaily", config.getVoiceDestinationDaily(), cacheCluster); + this.smsVoiceIpLimiter = fromConfig("smsVoiceIp", config.getSmsVoiceIp(), cacheCluster); + this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster); + this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster); + this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster); + this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster); + this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster); + this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster); + this.allocateDeviceLimiter = fromConfig("allocateDevice", config.getAllocateDevice(), cacheCluster); + this.verifyDeviceLimiter = fromConfig("verifyDevice", config.getVerifyDevice(), cacheCluster); + this.turnLimiter = fromConfig("turnAllocate", config.getTurnAllocations(), cacheCluster); + this.profileLimiter = fromConfig("profile", config.getProfile(), cacheCluster); + this.stickerPackLimiter = fromConfig("stickerPack", config.getStickerPack(), cacheCluster); + this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster); + this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster); + this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster); + this.storiesLimiter = fromConfig("stories", config.getStories(), cacheCluster); - this.voiceDestinationDailyLimiter = new RateLimiter(cacheCluster, "voxDestinationDaily", - config.getVoiceDestinationDaily().getBucketSize(), - config.getVoiceDestinationDaily().getLeakRatePerMinute()); + this.rateLimiterByHandle = Stream.of( + fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster), + fromConfig(Handle.CHECK_ACCOUNT_EXISTENCE.id(), config.getCheckAccountExistence(), cacheCluster), + fromConfig(Handle.USERNAME_LOOKUP.id(), config.getUsernameLookup(), cacheCluster) + ).map(rl -> Pair.of(rl.name, rl)).collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + } - this.smsVoiceIpLimiter = new RateLimiter(cacheCluster, "smsVoiceIp", - config.getSmsVoiceIp().getBucketSize(), - config.getSmsVoiceIp().getLeakRatePerMinute()); - - this.smsVoicePrefixLimiter = new RateLimiter(cacheCluster, "smsVoicePrefix", - config.getSmsVoicePrefix().getBucketSize(), - config.getSmsVoicePrefix().getLeakRatePerMinute()); - - this.verifyLimiter = new LockingRateLimiter(cacheCluster, "verify", - config.getVerifyNumber().getBucketSize(), - config.getVerifyNumber().getLeakRatePerMinute()); - - this.pinLimiter = new LockingRateLimiter(cacheCluster, "pin", - config.getVerifyPin().getBucketSize(), - config.getVerifyPin().getLeakRatePerMinute()); - - this.attachmentLimiter = new RateLimiter(cacheCluster, "attachmentCreate", - config.getAttachments().getBucketSize(), - config.getAttachments().getLeakRatePerMinute()); - - this.preKeysLimiter = new RateLimiter(cacheCluster, "prekeys", - config.getPreKeys().getBucketSize(), - config.getPreKeys().getLeakRatePerMinute()); - - this.messagesLimiter = new RateLimiter(cacheCluster, "messages", - config.getMessages().getBucketSize(), - config.getMessages().getLeakRatePerMinute()); - - this.allocateDeviceLimiter = new RateLimiter(cacheCluster, "allocateDevice", - config.getAllocateDevice().getBucketSize(), - config.getAllocateDevice().getLeakRatePerMinute()); - - this.verifyDeviceLimiter = new RateLimiter(cacheCluster, "verifyDevice", - config.getVerifyDevice().getBucketSize(), - config.getVerifyDevice().getLeakRatePerMinute()); - - this.turnLimiter = new RateLimiter(cacheCluster, "turnAllocate", - config.getTurnAllocations().getBucketSize(), - config.getTurnAllocations().getLeakRatePerMinute()); - - this.profileLimiter = new RateLimiter(cacheCluster, "profile", - config.getProfile().getBucketSize(), - config.getProfile().getLeakRatePerMinute()); - - this.stickerPackLimiter = new RateLimiter(cacheCluster, "stickerPack", - config.getStickerPack().getBucketSize(), - config.getStickerPack().getLeakRatePerMinute()); - - this.artPackLimiter = new RateLimiter(cacheCluster, "artPack", - config.getArtPack().getBucketSize(), - config.getArtPack().getLeakRatePerMinute()); - - this.usernameLookupLimiter = new RateLimiter(cacheCluster, "usernameLookup", - config.getUsernameLookup().getBucketSize(), - config.getUsernameLookup().getLeakRatePerMinute()); - - this.usernameSetLimiter = new RateLimiter(cacheCluster, "usernameSet", - config.getUsernameSet().getBucketSize(), - config.getUsernameSet().getLeakRatePerMinute()); - - this.usernameReserveLimiter = new RateLimiter(cacheCluster, "usernameReserve", - config.getUsernameReserve().getBucketSize(), - config.getUsernameReserve().getLeakRatePerMinute()); - - - this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence", - config.getCheckAccountExistence().getBucketSize(), - config.getCheckAccountExistence().getLeakRatePerMinute()); - - this.storiesLimiter = new RateLimiter(cacheCluster, "stories", - config.getStories().getBucketSize(), - config.getStories().getLeakRatePerMinute()); + public Optional byHandle(final Handle handle) { + return Optional.ofNullable(rateLimiterByHandle.get(handle.id())); } public RateLimiter getAllocateDeviceLimiter() { @@ -192,7 +155,7 @@ public class RateLimiters { } public RateLimiter getUsernameLookupLimiter() { - return usernameLookupLimiter; + return byHandle(Handle.USERNAME_LOOKUP).orElseThrow(); } public RateLimiter getUsernameSetLimiter() { @@ -204,8 +167,17 @@ public class RateLimiters { } public RateLimiter getCheckAccountExistenceLimiter() { - return checkAccountExistenceLimiter; + return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow(); } - public RateLimiter getStoriesLimiter() { return storiesLimiter; } + public RateLimiter getStoriesLimiter() { + return storiesLimiter; + } + + private static RateLimiter fromConfig( + final String name, + final RateLimitsConfiguration.RateLimitConfiguration cfg, + final FaultTolerantRedisCluster cacheCluster) { + return new RateLimiter(cacheCluster, name, cfg.getBucketSize(), cfg.getLeakRatePerMinute()); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/E164.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/E164.java new file mode 100644 index 000000000..8313aed84 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/E164.java @@ -0,0 +1,56 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Objects; +import javax.validation.Constraint; +import javax.validation.ConstraintValidator; +import javax.validation.ConstraintValidatorContext; +import javax.validation.Payload; + +/** + * Constraint annotation that requires annotated entity + * to hold (or return) a string value that is a valid E164-normalized phone number. + */ +@Target({ FIELD, PARAMETER, METHOD }) +@Retention(RUNTIME) +@Constraint(validatedBy = E164.Validator.class) +@Documented +public @interface E164 { + + String message() default "{org.whispersystems.textsecuregcm.util.E164.message}"; + + Class[] groups() default { }; + + Class[] payload() default { }; + + class Validator implements ConstraintValidator { + + @Override + public boolean isValid(final String value, final ConstraintValidatorContext context) { + if (Objects.isNull(value)) { + return true; + } + if (!value.startsWith("+")) { + return false; + } + try { + Util.requireNormalizedNumber(value); + } catch (final ImpossiblePhoneNumberException | NonNormalizedPhoneNumberException e) { + return false; + } + return true; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/HmacUtils.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/HmacUtils.java index 89951f204..6469fe01b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/HmacUtils.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/HmacUtils.java @@ -6,11 +6,12 @@ package org.whispersystems.textsecuregcm.util; import java.nio.charset.StandardCharsets; -import javax.crypto.Mac; -import javax.crypto.spec.SecretKeySpec; import java.security.InvalidKeyException; +import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.HexFormat; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; public final class HmacUtils { @@ -63,4 +64,14 @@ public final class HmacUtils { public static String hmac256TruncatedToHexString(final byte[] key, final String input, final int length) { return hmac256TruncatedToHexString(key, input.getBytes(StandardCharsets.UTF_8), length); } + + public static boolean hmacHexStringsEqual(final String expectedAsHexString, final String actualAsHexString) { + try { + final byte[] aBytes = HEX.parseHex(expectedAsHexString); + final byte[] bBytes = HEX.parseHex(actualAsHexString); + return MessageDigest.isEqual(aBytes, bBytes); + } catch (final IllegalArgumentException e) { + return false; + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java index e05340970..3f1d91d7c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java @@ -7,35 +7,48 @@ package org.whispersystems.textsecuregcm.util; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; +import java.util.Optional; import java.util.UUID; -public class UUIDUtil { +public final class UUIDUtil { - public static byte[] toBytes(final UUID uuid) { - return toByteBuffer(uuid).array(); - } + private UUIDUtil() { + // utility class + } - public static ByteBuffer toByteBuffer(final UUID uuid) { - final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); - byteBuffer.putLong(uuid.getMostSignificantBits()); - byteBuffer.putLong(uuid.getLeastSignificantBits()); - return byteBuffer.flip(); - } + public static byte[] toBytes(final UUID uuid) { + return toByteBuffer(uuid).array(); + } - public static UUID fromBytes(final byte[] bytes) { - return fromByteBuffer(ByteBuffer.wrap(bytes)); - } + public static ByteBuffer toByteBuffer(final UUID uuid) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); + byteBuffer.putLong(uuid.getMostSignificantBits()); + byteBuffer.putLong(uuid.getLeastSignificantBits()); + return byteBuffer.flip(); + } - public static UUID fromByteBuffer(final ByteBuffer byteBuffer) { - try { - final long mostSigBits = byteBuffer.getLong(); - final long leastSigBits = byteBuffer.getLong(); - if (byteBuffer.hasRemaining()) { - throw new IllegalArgumentException("unexpected byte array length; was greater than 16"); - } - return new UUID(mostSigBits, leastSigBits); - } catch (BufferUnderflowException e) { - throw new IllegalArgumentException("unexpected byte array length; was less than 16"); + public static UUID fromBytes(final byte[] bytes) { + return fromByteBuffer(ByteBuffer.wrap(bytes)); + } + + public static UUID fromByteBuffer(final ByteBuffer byteBuffer) { + try { + final long mostSigBits = byteBuffer.getLong(); + final long leastSigBits = byteBuffer.getLong(); + if (byteBuffer.hasRemaining()) { + throw new IllegalArgumentException("unexpected byte array length; was greater than 16"); } + return new UUID(mostSigBits, leastSigBits); + } catch (BufferUnderflowException e) { + throw new IllegalArgumentException("unexpected byte array length; was less than 16"); + } + } + + public static Optional fromStringSafe(final String uuidString) { + try { + return Optional.of(UUID.fromString(uuidString)); + } catch (final IllegalArgumentException e) { + return Optional.empty(); + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 0593296f3..3bbfc2682 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -89,6 +89,7 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameResponse; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.UsernameRequest; import org.whispersystems.textsecuregcm.entities.UsernameResponse; +import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; @@ -112,7 +113,7 @@ import org.whispersystems.textsecuregcm.storage.UsernameReservationNotFoundExcep import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.Hex; -import org.whispersystems.textsecuregcm.util.MockHelper; +import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestClock; @@ -172,7 +173,7 @@ class AccountControllerTest { private byte[] registration_lock_key = new byte[32]; - private static final SecureStorageServiceConfiguration STORAGE_CFG = MockHelper.buildMock( + private static final SecureStorageServiceConfiguration STORAGE_CFG = MockUtils.buildMock( SecureStorageServiceConfiguration.class, cfg -> when(cfg.decodeUserAuthenticationTokenSharedSecret()).thenReturn(new byte[32])); @@ -188,6 +189,7 @@ class AccountControllerTest { .addProvider(new RateLimitExceededExceptionMapper()) .addProvider(new ImpossiblePhoneNumberExceptionMapper()) .addProvider(new NonNormalizedPhoneNumberExceptionMapper()) + .addProvider(new RateLimitByIpFilter(rateLimiters)) .setMapper(SystemMapper.getMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new AccountController(pendingAccountsManager, @@ -1949,13 +1951,13 @@ class AccountControllerTest { @Test void testAccountExistsRateLimited() throws RateLimitExceededException { + final Duration expectedRetryAfter = Duration.ofSeconds(13); final Account account = mock(Account.class); final UUID accountIdentifier = UUID.randomUUID(); when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); - final RateLimiter checkAccountLimiter = mock(RateLimiter.class); - when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(checkAccountLimiter); - doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(checkAccountLimiter).validate("127.0.0.1"); + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter); final Response response = resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", accountIdentifier)) @@ -1964,7 +1966,7 @@ class AccountControllerTest { .head(); assertThat(response.getStatus()).isEqualTo(413); - assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds())); + assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(expectedRetryAfter.toSeconds())); } @Test @@ -2018,7 +2020,9 @@ class AccountControllerTest { @Test void testLookupUsernameRateLimited() throws RateLimitExceededException { - doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(usernameLookupLimiter).validate("127.0.0.1"); + final Duration expectedRetryAfter = Duration.ofSeconds(13); + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter); final Response response = resources.getJerseyTest() .target("/v1/accounts/username/test.123") .request() @@ -2026,7 +2030,7 @@ class AccountControllerTest { .get(); assertThat(response.getStatus()).isEqualTo(413); - assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds())); + assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(expectedRetryAfter.toSeconds())); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SecureBackupControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SecureBackupControllerTest.java new file mode 100644 index 000000000..46bd69cd3 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/SecureBackupControllerTest.java @@ -0,0 +1,291 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.apache.commons.lang3.RandomUtils; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration; +import org.whispersystems.textsecuregcm.entities.AuthCheckRequest; +import org.whispersystems.textsecuregcm.entities.AuthCheckResponse; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.MockUtils; +import org.whispersystems.textsecuregcm.util.MutableClock; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +@ExtendWith(DropwizardExtensionsSupport.class) +class SecureBackupControllerTest { + + private static final UUID USER_1 = UUID.randomUUID(); + + private static final UUID USER_2 = UUID.randomUUID(); + + private static final UUID USER_3 = UUID.randomUUID(); + + private static final String E164_VALID = "+18005550123"; + + private static final String E164_INVALID = "1(800)555-0123"; + + private static final byte[] SECRET = RandomUtils.nextBytes(32); + + private static final SecureBackupServiceConfiguration CFG = MockUtils.buildMock( + SecureBackupServiceConfiguration.class, + cfg -> Mockito.when(cfg.getUserAuthenticationTokenSharedSecret()).thenReturn(SECRET) + ); + + private static final MutableClock CLOCK = MockUtils.mutableClock(0); + + private static final ExternalServiceCredentialsGenerator CREDENTIAL_GENERATOR = + SecureBackupController.credentialsGenerator(CFG, CLOCK); + + private static final AccountsManager ACCOUNTS_MANAGER = Mockito.mock(AccountsManager.class); + + private static final SecureBackupController CONTROLLER = + new SecureBackupController(CREDENTIAL_GENERATOR, ACCOUNTS_MANAGER); + + private static final ResourceExtension RESOURCES = ResourceExtension.builder() + .addProvider(AuthHelper.getAuthFilter()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(CONTROLLER) + .build(); + + @BeforeAll + public static void before() throws Exception { + Mockito.when(ACCOUNTS_MANAGER.getByE164(E164_VALID)).thenReturn(Optional.of(account(USER_1))); + } + + @Test + public void testOneMatch() throws Exception { + validate(Map.of( + token(USER_1, day(1)), AuthCheckResponse.Result.MATCH, + token(USER_2, day(1)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(1)), AuthCheckResponse.Result.NO_MATCH + ), day(2)); + } + + @Test + public void testNoMatch() throws Exception { + validate(Map.of( + token(USER_2, day(1)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(1)), AuthCheckResponse.Result.NO_MATCH + ), day(2)); + } + + @Test + public void testEmptyInput() throws Exception { + validate(Collections.emptyMap(), day(2)); + } + + @Test + public void testSomeInvalid() throws Exception { + final String fakeToken = token(USER_3, day(1)).replaceAll(USER_3.toString(), USER_2.toString()); + validate(Map.of( + token(USER_1, day(1)), AuthCheckResponse.Result.MATCH, + token(USER_2, day(1)), AuthCheckResponse.Result.NO_MATCH, + fakeToken, AuthCheckResponse.Result.INVALID + ), day(2)); + } + + @Test + public void testSomeExpired() throws Exception { + validate(Map.of( + token(USER_1, day(100)), AuthCheckResponse.Result.MATCH, + token(USER_2, day(100)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(10)), AuthCheckResponse.Result.INVALID, + token(USER_3, day(20)), AuthCheckResponse.Result.INVALID + ), day(110)); + } + + @Test + public void testSomeHaveNewerVersions() throws Exception { + validate(Map.of( + token(USER_1, day(10)), AuthCheckResponse.Result.INVALID, + token(USER_1, day(20)), AuthCheckResponse.Result.MATCH, + token(USER_2, day(10)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(20)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(10)), AuthCheckResponse.Result.INVALID + ), day(25)); + } + + private static void validate( + final Map expected, + final long nowMillis) throws Exception { + CLOCK.setTimeMillis(nowMillis); + final AuthCheckRequest request = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet())); + final AuthCheckResponse response = CONTROLLER.authCheck(request); + assertEquals(expected, response.matches()); + } + + @Test + public void testHttpResponseCodeSuccess() throws Exception { + final Map expected = Map.of( + token(USER_1, day(10)), AuthCheckResponse.Result.INVALID, + token(USER_1, day(20)), AuthCheckResponse.Result.MATCH, + token(USER_2, day(10)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(20)), AuthCheckResponse.Result.NO_MATCH, + token(USER_3, day(10)), AuthCheckResponse.Result.INVALID + ); + + CLOCK.setTimeMillis(day(25)); + + final AuthCheckRequest in = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet())); + + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(in, MediaType.APPLICATION_JSON)); + + try (response) { + final AuthCheckResponse res = response.readEntity(AuthCheckResponse.class); + assertEquals(200, response.getStatus()); + assertEquals(expected, res.matches()); + } + } + + @Test + public void testHttpResponseCodeWhenInvalidNumber() throws Exception { + final AuthCheckRequest in = new AuthCheckRequest(E164_INVALID, Collections.singletonList("1")); + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(in, MediaType.APPLICATION_JSON)); + + try (response) { + assertEquals(422, response.getStatus()); + } + } + + @Test + public void testHttpResponseCodeWhenTooManyTokens() throws Exception { + final AuthCheckRequest inOkay = new AuthCheckRequest(E164_VALID, List.of( + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10" + )); + final AuthCheckRequest inTooMany = new AuthCheckRequest(E164_VALID, List.of( + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11" + )); + final AuthCheckRequest inNoTokens = new AuthCheckRequest(E164_VALID, Collections.emptyList()); + + final Response responseOkay = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(inOkay, MediaType.APPLICATION_JSON)); + + final Response responseError1 = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(inTooMany, MediaType.APPLICATION_JSON)); + + final Response responseError2 = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(inNoTokens, MediaType.APPLICATION_JSON)); + + try (responseOkay; responseError1; responseError2) { + assertEquals(200, responseOkay.getStatus()); + assertEquals(422, responseError1.getStatus()); + assertEquals(422, responseError2.getStatus()); + } + } + + @Test + public void testHttpResponseCodeWhenPasswordsMissing() throws Exception { + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(""" + { + "number": "123" + } + """, MediaType.APPLICATION_JSON)); + + try (response) { + assertEquals(422, response.getStatus()); + } + } + + @Test + public void testHttpResponseCodeWhenNumberMissing() throws Exception { + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(""" + { + "passwords": ["aaa:bbb"] + } + """, MediaType.APPLICATION_JSON)); + + try (response) { + assertEquals(422, response.getStatus()); + } + } + + @Test + public void testHttpResponseCodeWhenExtraFields() throws Exception { + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity(""" + { + "number": "+18005550123", + "passwords": ["aaa:bbb"], + "unexpected": "value" + } + """, MediaType.APPLICATION_JSON)); + + try (response) { + assertEquals(200, response.getStatus()); + } + } + + @Test + public void testHttpResponseCodeWhenNotAJson() throws Exception { + final Response response = RESOURCES.getJerseyTest() + .target("/v1/backup/auth/check") + .request() + .post(Entity.entity("random text", MediaType.APPLICATION_JSON)); + + try (response) { + assertEquals(400, response.getStatus()); + } + } + + private static String token(final UUID uuid, final long timeMillis) { + CLOCK.setTimeMillis(timeMillis); + final ExternalServiceCredentials credentials = CREDENTIAL_GENERATOR.generateForUuid(uuid); + return credentials.username() + ":" + credentials.password(); + } + + private static long day(final int n) { + return TimeUnit.DAYS.toMillis(n); + } + + private static Account account(final UUID uuid) { + final Account a = new Account(); + a.setUuid(uuid); + return a; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java new file mode 100644 index 000000000..be1268f68 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.net.HttpHeaders; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.time.Duration; +import java.util.Optional; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.core.Response; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.util.MockUtils; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +@ExtendWith(DropwizardExtensionsSupport.class) +public class RateLimitedByIpTest { + + private static final String IP = "70.130.130.200"; + + private static final String VALID_X_FORWARDED_FOR = "1.1.1.1," + IP; + + private static final String INVALID_X_FORWARDED_FOR = "1.1.1.1,"; + + private static final Duration RETRY_AFTER = Duration.ofSeconds(100); + + private static final Duration RETRY_AFTER_INVALID_HEADER = RateLimitByIpFilter.INVALID_HEADER_EXCEPTION + .getRetryDuration() + .orElseThrow(); + + + @Path("/test") + public static class Controller { + @GET + @Path("/strict") + @RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK) + public Response strict() { + return Response.ok().build(); + } + + @GET + @Path("/loose") + @RateLimitedByIp(value = RateLimiters.Handle.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false) + public Response loose() { + return Response.ok().build(); + } + } + + private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class); + + private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl -> + Mockito.when(rl.byHandle(Mockito.eq(RateLimiters.Handle.BACKUP_AUTH_CHECK))).thenReturn(Optional.of(RATE_LIMITER))); + + private static final ResourceExtension RESOURCES = ResourceExtension.builder() + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new Controller()) + .addProvider(new RateLimitByIpFilter(RATE_LIMITERS)) + .build(); + + @Test + public void testRateLimits() throws Exception { + Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); + validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); + validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); + Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); + validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); + validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); + } + + @Test + public void testInvalidHeader() throws Exception { + Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); + validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); + validateFailure("/test/strict", INVALID_X_FORWARDED_FOR, RETRY_AFTER_INVALID_HEADER); + validateFailure("/test/strict", "", RETRY_AFTER_INVALID_HEADER); + + validateSuccess("/test/loose", VALID_X_FORWARDED_FOR); + validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); + validateSuccess("/test/loose", ""); + + // also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.anyString()); + validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER); + validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); + validateSuccess("/test/loose", ""); + } + + private static void validateSuccess(final String path, final String xff) { + final Response response = RESOURCES.getJerseyTest() + .target(path) + .request() + .header(HttpHeaders.X_FORWARDED_FOR, xff) + .get(); + + assertEquals(200, response.getStatus()); + } + + private static void validateFailure(final String path, final String xff, final Duration expectedRetryAfter) { + final Response response = RESOURCES.getJerseyTest() + .target(path) + .request() + .header(HttpHeaders.X_FORWARDED_FOR, xff) + .get(); + + assertEquals(413, response.getStatus()); + assertEquals("" + expectedRetryAfter.getSeconds(), response.getHeaderString(HttpHeaders.RETRY_AFTER)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java index e8aa81cce..5281ea240 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java @@ -1,3 +1,8 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertAll; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/ExternalServiceCredentialsGeneratorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/ExternalServiceCredentialsGeneratorTest.java index a993ee88b..f492a5c1c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/ExternalServiceCredentialsGeneratorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/ExternalServiceCredentialsGeneratorTest.java @@ -5,24 +5,39 @@ package org.whispersystems.textsecuregcm.tests.auth; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.util.MockUtils; +import org.whispersystems.textsecuregcm.util.MutableClock; class ExternalServiceCredentialsGeneratorTest { + private static final String E164 = "+14152222222"; + + private static final long TIME_SECONDS = 12345; + + private static final long TIME_MILLIS = TimeUnit.SECONDS.toMillis(TIME_SECONDS); + + private static final String TIME_SECONDS_STRING = Long.toString(TIME_SECONDS); + + @Test void testGenerateDerivedUsername() { final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator .builder(new byte[32]) .withUserDerivationKey(new byte[32]) .build(); - final ExternalServiceCredentials credentials = generator.generateFor("+14152222222"); - - assertThat(credentials.username()).isNotEqualTo("+14152222222"); - assertThat(credentials.password().startsWith("+14152222222")).isFalse(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + assertNotEquals(credentials.username(), E164); + assertFalse(credentials.password().startsWith(E164)); + assertEquals(credentials.password().split(":").length, 3); } @Test @@ -30,10 +45,71 @@ class ExternalServiceCredentialsGeneratorTest { final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator .builder(new byte[32]) .build(); - final ExternalServiceCredentials credentials = generator.generateFor("+14152222222"); - - assertThat(credentials.username()).isEqualTo("+14152222222"); - assertThat(credentials.password().startsWith("+14152222222")).isTrue(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + assertEquals(credentials.username(), E164); + assertTrue(credentials.password().startsWith(E164)); + assertEquals(credentials.password().split(":").length, 3); } + @Test + public void testNotPrependUsername() throws Exception { + final MutableClock clock = MockUtils.mutableClock(TIME_MILLIS); + final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator + .builder(new byte[32]) + .prependUsername(false) + .withClock(clock) + .build(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + assertEquals(credentials.username(), E164); + assertTrue(credentials.password().startsWith(TIME_SECONDS_STRING)); + assertEquals(credentials.password().split(":").length, 2); + } + + @Test + public void testValidateValid() throws Exception { + final MutableClock clock = MockUtils.mutableClock(TIME_MILLIS); + final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator + .builder(new byte[32]) + .withClock(clock) + .build(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + assertEquals(generator.validateAndGetTimestamp(credentials).orElseThrow(), TIME_SECONDS); + } + + @Test + public void testValidateInvalid() throws Exception { + final MutableClock clock = MockUtils.mutableClock(TIME_MILLIS); + final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator + .builder(new byte[32]) + .withClock(clock) + .build(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + + final ExternalServiceCredentials corruptedUsername = new ExternalServiceCredentials( + credentials.username(), credentials.password().replace(E164, E164 + "0")); + final ExternalServiceCredentials corruptedTimestamp = new ExternalServiceCredentials( + credentials.username(), credentials.password().replace(TIME_SECONDS_STRING, TIME_SECONDS_STRING + "0")); + final ExternalServiceCredentials corruptedPassword = new ExternalServiceCredentials( + credentials.username(), credentials.password() + "0"); + + assertTrue(generator.validateAndGetTimestamp(corruptedUsername).isEmpty()); + assertTrue(generator.validateAndGetTimestamp(corruptedTimestamp).isEmpty()); + assertTrue(generator.validateAndGetTimestamp(corruptedPassword).isEmpty()); + } + + @Test + public void testValidateWithExpiration() throws Exception { + final MutableClock clock = MockUtils.mutableClock(TIME_MILLIS); + final ExternalServiceCredentialsGenerator generator = ExternalServiceCredentialsGenerator + .builder(new byte[32]) + .withClock(clock) + .build(); + final ExternalServiceCredentials credentials = generator.generateFor(E164); + + final long elapsedSeconds = 10000; + clock.incrementSeconds(elapsedSeconds); + + assertEquals(generator.validateAndGetTimestamp(credentials, elapsedSeconds + 1).orElseThrow(), TIME_SECONDS); + assertTrue(generator.validateAndGetTimestamp(credentials, elapsedSeconds - 1).isEmpty()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ArtControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ArtControllerTest.java index 7136657ff..c146ce8cc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ArtControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ArtControllerTest.java @@ -26,13 +26,13 @@ import org.whispersystems.textsecuregcm.controllers.ArtController; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; -import org.whispersystems.textsecuregcm.util.MockHelper; +import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) class ArtControllerTest { - private static final ArtServiceConfiguration ART_SERVICE_CONFIGURATION = MockHelper.buildMock( + private static final ArtServiceConfiguration ART_SERVICE_CONFIGURATION = MockUtils.buildMock( ArtServiceConfiguration.class, cfg -> { Mockito.when(cfg.getUserAuthenticationTokenSharedSecret()).thenReturn(new byte[32]); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java index c95330788..a4883c82d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java @@ -23,13 +23,13 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; -import org.whispersystems.textsecuregcm.util.MockHelper; +import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) class SecureStorageControllerTest { - private static final SecureStorageServiceConfiguration STORAGE_CFG = MockHelper.buildMock( + private static final SecureStorageServiceConfiguration STORAGE_CFG = MockUtils.buildMock( SecureStorageServiceConfiguration.class, cfg -> when(cfg.decodeUserAuthenticationTokenSharedSecret()).thenReturn(new byte[32])); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/E164Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/E164Test.java new file mode 100644 index 000000000..8901e00df --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/E164Test.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Method; +import java.util.Set; +import javax.validation.ConstraintViolation; +import javax.validation.Validation; +import javax.validation.Validator; +import org.junit.jupiter.api.Test; + +public class E164Test { + + private static final Validator VALIDATOR = Validation.buildDefaultValidatorFactory().getValidator(); + + private static final String E164_VALID = "+18005550123"; + + private static final String E164_INVALID = "1(800)555-0123"; + + private static final String EMPTY = ""; + + @SuppressWarnings("FieldCanBeLocal") + private static class Data { + + @E164 + private final String number; + + private Data(final String number) { + this.number = number; + } + } + + private static class Methods { + + public void foo(@E164 final String number) { + // noop + } + + @E164 + public String bar() { + return "nevermind"; + } + } + + private record Rec(@E164 String number) { + } + + @Test + public void testRecord() throws Exception { + checkNoViolations(new Rec(E164_VALID)); + checkHasViolations(new Rec(E164_INVALID)); + checkHasViolations(new Rec(EMPTY)); + } + + @Test + public void testClassField() throws Exception { + checkNoViolations(new Data(E164_VALID)); + checkHasViolations(new Data(E164_INVALID)); + checkHasViolations(new Data(EMPTY)); + } + + @Test + public void testParameters() throws Exception { + final Methods m = new Methods(); + final Method foo = Methods.class.getMethod("foo", String.class); + + final Set> violations1 = + VALIDATOR.forExecutables().validateParameters(m, foo, new Object[] {E164_VALID}); + final Set> violations2 = + VALIDATOR.forExecutables().validateParameters(m, foo, new Object[] {E164_INVALID}); + final Set> violations3 = + VALIDATOR.forExecutables().validateParameters(m, foo, new Object[] {EMPTY}); + + assertTrue(violations1.isEmpty()); + assertFalse(violations2.isEmpty()); + assertFalse(violations3.isEmpty()); + } + + @Test + public void testReturnValue() throws Exception { + final Methods m = new Methods(); + final Method bar = Methods.class.getMethod("bar"); + + final Set> violations1 = + VALIDATOR.forExecutables().validateReturnValue(m, bar, E164_VALID); + final Set> violations2 = + VALIDATOR.forExecutables().validateReturnValue(m, bar, E164_INVALID); + final Set> violations3 = + VALIDATOR.forExecutables().validateReturnValue(m, bar, EMPTY); + + assertTrue(violations1.isEmpty()); + assertFalse(violations2.isEmpty()); + assertFalse(violations3.isEmpty()); + } + + private static void checkNoViolations(final T object) { + final Set> violations = VALIDATOR.validate(object); + assertTrue(violations.isEmpty()); + } + + private static void checkHasViolations(final T object) { + final Set> violations = VALIDATOR.validate(object); + assertFalse(violations.isEmpty()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockHelper.java deleted file mode 100644 index 53c3eb0ea..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockHelper.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.util; - -import org.mockito.Mockito; - -public final class MockHelper { - - private MockHelper() { - // utility class - } - - @FunctionalInterface - public interface MockInitializer { - - void init(T mock) throws Exception; - } - - public static T buildMock(final Class clazz, final MockInitializer initializer) throws RuntimeException { - final T mock = Mockito.mock(clazz); - try { - initializer.init(mock); - } catch (Exception e) { - throw new RuntimeException(e); - } - return mock; - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java new file mode 100644 index 000000000..d39f43785 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; + +import java.time.Duration; +import java.util.Optional; +import org.mockito.Mockito; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; + +public final class MockUtils { + + private MockUtils() { + // utility class + } + + @FunctionalInterface + public interface MockInitializer { + + void init(T mock) throws Exception; + } + + public static T buildMock(final Class clazz, final MockInitializer initializer) throws RuntimeException { + final T mock = Mockito.mock(clazz); + try { + initializer.init(mock); + } catch (Exception e) { + throw new RuntimeException(e); + } + return mock; + } + + public static MutableClock mutableClock(final long timeMillis) { + return new MutableClock(timeMillis); + } + + public static void updateRateLimiterResponseToAllow( + final RateLimiters rateLimitersMock, + final RateLimiters.Handle handle, + final String input) { + final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); + doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); + try { + doNothing().when(mockRateLimiter).validate(eq(input)); + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } + + public static void updateRateLimiterResponseToFail( + final RateLimiters rateLimitersMock, + final RateLimiters.Handle handle, + final String input, + final Duration retryAfter) { + final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); + doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); + try { + doThrow(new RateLimitExceededException(retryAfter)).when(mockRateLimiter).validate(eq(input)); + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MutableClock.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MutableClock.java new file mode 100644 index 000000000..5fae8f9ff --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MutableClock.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class MutableClock extends Clock { + + private final AtomicReference delegate; + + + public MutableClock(final long timeMillis) { + this(fixedTimeMillis(timeMillis)); + } + + public MutableClock(final Clock clock) { + this.delegate = new AtomicReference<>(clock); + } + + public MutableClock() { + this(Clock.systemUTC()); + } + + public MutableClock setTimeMillis(final long timeMillis) { + delegate.set(fixedTimeMillis(timeMillis)); + return this; + } + + public MutableClock incrementMillis(final long incrementMillis) { + return increment(incrementMillis, TimeUnit.MILLISECONDS); + } + + public MutableClock incrementSeconds(final long incrementSeconds) { + return increment(incrementSeconds, TimeUnit.SECONDS); + } + + public MutableClock increment(final long increment, final TimeUnit timeUnit) { + final long current = delegate.get().instant().toEpochMilli(); + delegate.set(fixedTimeMillis(current + timeUnit.toMillis(increment))); + return this; + } + + @Override + public ZoneId getZone() { + return delegate.get().getZone(); + } + + @Override + public Clock withZone(final ZoneId zone) { + return delegate.get().withZone(zone); + } + + @Override + public Instant instant() { + return delegate.get().instant(); + } + + private static Clock fixedTimeMillis(final long timeMillis) { + return Clock.fixed(Instant.ofEpochMilli(timeMillis), ZoneId.of("Etc/UTC")); + } +}