diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java index a4bb7a2a1..051f5a612 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java @@ -14,7 +14,9 @@ import java.security.InvalidKeyException; import java.util.LinkedList; import java.util.List; import java.util.Optional; +import java.util.UUID; import javax.ws.rs.GET; +import javax.ws.rs.NotFoundException; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; @@ -70,7 +72,8 @@ public class CertificateController { @Path("/group/{startRedemptionTime}/{endRedemptionTime}") public GroupCredentials getAuthenticationCredentials(@Auth AuthenticatedAccount auth, @PathParam("startRedemptionTime") int startRedemptionTime, - @PathParam("endRedemptionTime") int endRedemptionTime) { + @PathParam("endRedemptionTime") int endRedemptionTime, + @QueryParam("identity") Optional identityType) { if (startRedemptionTime > endRedemptionTime) { throw new WebApplicationException(Response.Status.BAD_REQUEST); } @@ -83,10 +86,13 @@ public class CertificateController { List credentials = new LinkedList<>(); + final UUID identifier = identityType.map(String::toLowerCase).orElse("aci").equals("pni") ? + auth.getAccount().getPhoneNumberIdentifier().orElseThrow(NotFoundException::new) : + auth.getAccount().getUuid(); + for (int i = startRedemptionTime; i <= endRedemptionTime; i++) { credentials.add(new GroupCredentials.GroupCredential( - serverZkAuthOperations.issueAuthCredential(auth.getAccount().getUuid(), i) - .serialize(), + serverZkAuthOperations.issueAuthCredential(identifier, i).serialize(), i)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java index 0e042cec2..206fd9faf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java @@ -7,9 +7,11 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; @@ -17,12 +19,15 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import java.io.IOException; import java.util.Base64; +import java.util.Optional; +import java.util.UUID; import javax.ws.rs.core.Response; import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.signal.zkgroup.ServerSecretParams; +import org.signal.zkgroup.VerificationFailedException; import org.signal.zkgroup.auth.AuthCredentialResponse; import org.signal.zkgroup.auth.ClientZkAuthOperations; import org.signal.zkgroup.auth.ServerZkAuthOperations; @@ -213,6 +218,41 @@ class CertificateControllerTest { .doesNotThrowAnyException(); } + @Test + void testGetSingleAuthCredentialByPni() { + when(AuthHelper.VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(Optional.of(UUID.randomUUID())); + + GroupCredentials credentials = resources.getJerseyTest() + .target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + Util.currentDaysSinceEpoch()) + .queryParam("identity", "pni") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(GroupCredentials.class); + + assertThat(credentials.getCredentials().size()).isEqualTo(1); + assertThat(credentials.getCredentials().get(0).getRedemptionTime()).isEqualTo(Util.currentDaysSinceEpoch()); + + ClientZkAuthOperations clientZkAuthOperations = new ClientZkAuthOperations(serverSecretParams.getPublicParams()); + + assertThatExceptionOfType(VerificationFailedException.class) + .isThrownBy(() -> + clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, Util.currentDaysSinceEpoch(), new AuthCredentialResponse(credentials.getCredentials().get(0).getCredential()))); + } + + @Test + void testGetSingleAuthCredentialByPniNotSet() { + when(AuthHelper.VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(Optional.empty()); + + Response response = resources.getJerseyTest() + .target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + Util.currentDaysSinceEpoch()) + .queryParam("identity", "pni") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(); + + assertThat(response.getStatus()).isEqualTo(404); + } + @Test void testGetWeekLongAuthCredentials() { GroupCredentials credentials = resources.getJerseyTest()