Add tests for `ProvisioningController`
This commit is contained in:
parent
4fc3949367
commit
0f17d63774
|
@ -48,7 +48,7 @@ public class ProvisioningController {
|
||||||
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
|
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
|
||||||
|
|
||||||
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
|
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
|
||||||
Base64.getMimeDecoder().decode(message.getBody()))) {
|
Base64.getMimeDecoder().decode(message.body()))) {
|
||||||
throw new WebApplicationException(Response.Status.NOT_FOUND);
|
throw new WebApplicationException(Response.Status.NOT_FOUND);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,16 +5,7 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
import javax.validation.constraints.NotEmpty;
|
import javax.validation.constraints.NotEmpty;
|
||||||
|
|
||||||
public class ProvisioningMessage {
|
public record ProvisioningMessage(@NotEmpty String body) {
|
||||||
|
|
||||||
@JsonProperty
|
|
||||||
@NotEmpty
|
|
||||||
private String body;
|
|
||||||
|
|
||||||
public String getBody() {
|
|
||||||
return body;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
package org.whispersystems.textsecuregcm.controllers;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableSet;
|
||||||
|
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
|
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||||
|
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
|
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||||
|
|
||||||
|
import javax.ws.rs.client.Entity;
|
||||||
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.reset;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
|
class ProvisioningControllerTest {
|
||||||
|
|
||||||
|
private RateLimiter messagesRateLimiter;
|
||||||
|
|
||||||
|
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||||
|
private static final ProvisioningManager provisioningManager = mock(ProvisioningManager.class);
|
||||||
|
|
||||||
|
private static final ResourceExtension RESOURCE_EXTENSION = ResourceExtension.builder()
|
||||||
|
.addProvider(AuthHelper.getAuthFilter())
|
||||||
|
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||||
|
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
|
||||||
|
.addProvider(new RateLimitExceededExceptionMapper())
|
||||||
|
.setMapper(SystemMapper.getMapper())
|
||||||
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
|
.addResource(new ProvisioningController(rateLimiters, provisioningManager))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
reset(rateLimiters, provisioningManager);
|
||||||
|
|
||||||
|
messagesRateLimiter = mock(RateLimiter.class);
|
||||||
|
when(rateLimiters.getMessagesLimiter()).thenReturn(messagesRateLimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void sendProvisioningMessage() {
|
||||||
|
final String destination = UUID.randomUUID().toString();
|
||||||
|
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
when(provisioningManager.sendProvisioningMessage(any(), any())).thenReturn(true);
|
||||||
|
|
||||||
|
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
|
||||||
|
.target("/v1/provisioning/" + destination)
|
||||||
|
.request()
|
||||||
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||||
|
.put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)),
|
||||||
|
MediaType.APPLICATION_JSON))) {
|
||||||
|
|
||||||
|
assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus());
|
||||||
|
|
||||||
|
final ArgumentCaptor<ProvisioningAddress> provisioningAddressCaptor =
|
||||||
|
ArgumentCaptor.forClass(ProvisioningAddress.class);
|
||||||
|
|
||||||
|
final ArgumentCaptor<byte[]> provisioningMessageCaptor = ArgumentCaptor.forClass(byte[].class);
|
||||||
|
|
||||||
|
verify(provisioningManager).sendProvisioningMessage(provisioningAddressCaptor.capture(),
|
||||||
|
provisioningMessageCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(destination, provisioningAddressCaptor.getValue().getAddress());
|
||||||
|
assertEquals(0, provisioningAddressCaptor.getValue().getDeviceId());
|
||||||
|
|
||||||
|
assertArrayEquals(messageBody, provisioningMessageCaptor.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void sendProvisioningMessageRateLimited() throws RateLimitExceededException {
|
||||||
|
final String destination = UUID.randomUUID().toString();
|
||||||
|
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
doThrow(new RateLimitExceededException(Duration.ZERO))
|
||||||
|
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
|
||||||
|
|
||||||
|
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
|
||||||
|
.target("/v1/provisioning/" + destination)
|
||||||
|
.request()
|
||||||
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||||
|
.put(Entity.entity(new ProvisioningMessage(Base64.getMimeEncoder().encodeToString(messageBody)),
|
||||||
|
MediaType.APPLICATION_JSON))) {
|
||||||
|
|
||||||
|
assertEquals(413, response.getStatus());
|
||||||
|
|
||||||
|
verify(provisioningManager, never()).sendProvisioningMessage(any(), any());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue