Add tests for `ProvisioningController`

This commit is contained in:
Jon Chambers 2023-02-09 12:04:52 -05:00 committed by GitHub
parent 4fc3949367
commit 0f17d63774
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 121 additions and 11 deletions

View File

@ -48,7 +48,7 @@ public class ProvisioningController {
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
Base64.getMimeDecoder().decode(message.getBody()))) {
Base64.getMimeDecoder().decode(message.body()))) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
}

View File

@ -5,16 +5,7 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.validation.constraints.NotEmpty;
public class ProvisioningMessage {
@JsonProperty
@NotEmpty
private String body;
public String getBody() {
return body;
}
public record ProvisioningMessage(@NotEmpty String body) {
}

View File

@ -0,0 +1,119 @@
package org.whispersystems.textsecuregcm.controllers;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(DropwizardExtensionsSupport.class)
class ProvisioningControllerTest {
private RateLimiter messagesRateLimiter;
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final ProvisioningManager provisioningManager = mock(ProvisioningManager.class);
private static final ResourceExtension RESOURCE_EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProvisioningController(rateLimiters, provisioningManager))
.build();
@BeforeEach
void setUp() {
reset(rateLimiters, provisioningManager);
messagesRateLimiter = mock(RateLimiter.class);
when(rateLimiters.getMessagesLimiter()).thenReturn(messagesRateLimiter);
}
@Test
void sendProvisioningMessage() {
final String destination = UUID.randomUUID().toString();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true);
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
.target("/v1/provisioning/" + destination)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)),
MediaType.APPLICATION_JSON))) {
assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus());
final ArgumentCaptor<ProvisioningAddress> provisioningAddressCaptor =
ArgumentCaptor.forClass(ProvisioningAddress.class);
final ArgumentCaptor<byte[]> provisioningMessageCaptor = ArgumentCaptor.forClass(byte[].class);
verify(provisioningManager).sendProvisioningMessage(provisioningAddressCaptor.capture(),
provisioningMessageCaptor.capture());
assertEquals(destination, provisioningAddressCaptor.getValue().getAddress());
assertEquals(0, provisioningAddressCaptor.getValue().getDeviceId());
assertArrayEquals(messageBody, provisioningMessageCaptor.getValue());
}
}
@Test
void sendProvisioningMessageRateLimited() throws RateLimitExceededException {
final String destination = UUID.randomUUID().toString();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
doThrow(new RateLimitExceededException(Duration.ZERO))
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
.target("/v1/provisioning/" + destination)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)),
MediaType.APPLICATION_JSON))) {
assertEquals(413, response.getStatus());
verify(provisioningManager, never()).sendProvisioningMessage(any(), any());
}
}
}