From 0f17d63774bdf119501e2b1e75878fb191a2b45b Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Thu, 9 Feb 2023 12:04:52 -0500 Subject: [PATCH] Add tests for `ProvisioningController` --- .../controllers/ProvisioningController.java | 2 +- .../entities/ProvisioningMessage.java | 11 +- .../ProvisioningControllerTest.java | 119 ++++++++++++++++++ 3 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java index 6310ded13..a941d6dbd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java @@ -48,7 +48,7 @@ public class ProvisioningController { rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid()); if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0), - Base64.getMimeDecoder().decode(message.getBody()))) { + Base64.getMimeDecoder().decode(message.body()))) { throw new WebApplicationException(Response.Status.NOT_FOUND); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ProvisioningMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ProvisioningMessage.java index 09c857a16..b6c6dff8b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ProvisioningMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ProvisioningMessage.java @@ -5,16 +5,7 @@ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonProperty; import javax.validation.constraints.NotEmpty; -public class ProvisioningMessage { - - @JsonProperty - @NotEmpty - private String body; - - public String getBody() { - return body; - } +public record ProvisioningMessage(@NotEmpty String body) { } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java new file mode 100644 index 000000000..31c6c7ad8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java @@ -0,0 +1,119 @@ +package org.whispersystems.textsecuregcm.controllers; + +import com.google.common.collect.ImmutableSet; +import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; + +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Base64; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(DropwizardExtensionsSupport.class) +class ProvisioningControllerTest { + + private RateLimiter messagesRateLimiter; + + private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final ProvisioningManager provisioningManager = mock(ProvisioningManager.class); + + private static final ResourceExtension RESOURCE_EXTENSION = ResourceExtension.builder() + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .addProvider(new RateLimitExceededExceptionMapper()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new ProvisioningController(rateLimiters, provisioningManager)) + .build(); + + @BeforeEach + void setUp() { + reset(rateLimiters, provisioningManager); + + messagesRateLimiter = mock(RateLimiter.class); + when(rateLimiters.getMessagesLimiter()).thenReturn(messagesRateLimiter); + } + + @Test + void sendProvisioningMessage() { + final String destination = UUID.randomUUID().toString(); + final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); + + when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true); + + try (final Response response = RESOURCE_EXTENSION.getJerseyTest() + .target("/v1/provisioning/" + destination) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)), + MediaType.APPLICATION_JSON))) { + + assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus()); + + final ArgumentCaptor provisioningAddressCaptor = + ArgumentCaptor.forClass(ProvisioningAddress.class); + + final ArgumentCaptor provisioningMessageCaptor = ArgumentCaptor.forClass(byte[].class); + + verify(provisioningManager).sendProvisioningMessage(provisioningAddressCaptor.capture(), + provisioningMessageCaptor.capture()); + + assertEquals(destination, provisioningAddressCaptor.getValue().getAddress()); + assertEquals(0, provisioningAddressCaptor.getValue().getDeviceId()); + + assertArrayEquals(messageBody, provisioningMessageCaptor.getValue()); + } + } + + @Test + void sendProvisioningMessageRateLimited() throws RateLimitExceededException { + final String destination = UUID.randomUUID().toString(); + final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); + + doThrow(new RateLimitExceededException(Duration.ZERO)) + .when(messagesRateLimiter).validate(AuthHelper.VALID_UUID); + + try (final Response response = RESOURCE_EXTENSION.getJerseyTest() + .target("/v1/provisioning/" + destination) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)), + MediaType.APPLICATION_JSON))) { + + assertEquals(413, response.getStatus()); + + verify(provisioningManager, never()).sendProvisioningMessage(any(), any()); + } + } +}