diff --git a/integration-tests/src/test/java/org/signal/integration/MessagingTest.java b/integration-tests/src/test/java/org/signal/integration/MessagingTest.java index 45ff46593..cf32b3a8a 100644 --- a/integration-tests/src/test/java/org/signal/integration/MessagingTest.java +++ b/integration-tests/src/test/java/org/signal/integration/MessagingTest.java @@ -8,7 +8,6 @@ package org.signal.integration; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.List; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -27,11 +26,10 @@ public class MessagingTest { try { final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8); - final String contentBase64 = Base64.getEncoder().encodeToString(expectedContent); - final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), contentBase64); + final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), expectedContent); final IncomingMessageList messages = new IncomingMessageList(List.of(message), false, true, System.currentTimeMillis()); - final Pair sendMessage = Operations + Operations .apiPut("/v1/messages/%s".formatted(userB.aciUuid().toString()), messages) .authorized(userA) .execute(SendMessageResponse.class); @@ -40,7 +38,7 @@ public class MessagingTest { .authorized(userB) .execute(OutgoingMessageEntityList.class); - final byte[] actualContent = receiveMessages.getRight().messages().get(0).content(); + final byte[] actualContent = receiveMessages.getRight().messages().getFirst().content(); assertArrayEquals(expectedContent, actualContent); } finally { Operations.deleteUser(userA); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index a26ef2a86..85ad5c493 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -62,7 +62,6 @@ import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.ServiceId; @@ -324,7 +323,7 @@ public class MessageController { int totalContentLength = 0; for (final IncomingMessage message : messages.messages()) { - final int contentLength = decodedSize(message.content()); + final int contentLength = message.content() != null ? message.content().length : 0; validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent); @@ -955,19 +954,4 @@ public class MessageController { .increment(); } } - - @VisibleForTesting - static int decodedSize(final String base64) { - final int padding; - - if (StringUtils.endsWith(base64, "==")) { - padding = 2; - } else if (StringUtils.endsWith(base64, "=")) { - padding = 1; - } else { - padding = 0; - } - - return ((StringUtils.length(base64) - padding) * 3) / 4; - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 5ade8597d..f78cdcc37 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -4,18 +4,25 @@ */ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.protobuf.ByteString; +import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deserializer; import io.micrometer.core.instrument.Metrics; import jakarta.validation.constraints.AssertTrue; -import java.util.Base64; import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.Account; +import java.util.Arrays; +import java.util.Objects; -public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) { +public record IncomingMessage(int type, + byte destinationDeviceId, + int destinationRegistrationId, + + @JsonDeserialize(using = ByteArrayBase64Deserializer.class) + byte[] content) { private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME = MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType"); @@ -50,8 +57,8 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)); } - if (StringUtils.isNotEmpty(content())) { - envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content()))); + if (content() != null && content().length > 0) { + envelopeBuilder.setContent(ByteString.copyFrom(content())); } return envelopeBuilder.build(); @@ -69,4 +76,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio return true; } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof IncomingMessage(int otherType, byte otherDeviceId, int otherRegistrationId, byte[] otherContent))) + return false; + return type == otherType && destinationDeviceId == otherDeviceId + && destinationRegistrationId == otherRegistrationId && Objects.deepEquals(content, otherContent); + } + + @Override + public int hashCode() { + return Objects.hash(type, destinationDeviceId, destinationRegistrationId, Arrays.hashCode(content)); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index d0e7da979..9a0e4e5a7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -13,7 +13,6 @@ import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; -import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -138,16 +137,11 @@ public class ChangeNumberManager { } private static Optional getMessageContent(final IncomingMessage message) { - if (StringUtils.isEmpty(message.content())) { + if (message.content() == null || message.content().length == 0) { logger.warn("Message has no content"); return Optional.empty(); } - try { - return Optional.of(Base64.getDecoder().decode(message.content())); - } catch (final IllegalArgumentException e) { - logger.warn("Failed to parse message content", e); - return Optional.empty(); - } + return Optional.of(message.content()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 7ba048fd7..2549b5521 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -1096,7 +1096,7 @@ class MessageControllerTest { .request() .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(new IncomingMessageList( - List.of(new IncomingMessage(1, (byte) 1, 1, Base64.getEncoder().encodeToString(contentBytes))), false, true, + List.of(new IncomingMessage(1, (byte) 1, 1, contentBytes)), false, true, System.currentTimeMillis()), MediaType.APPLICATION_JSON_TYPE))) { @@ -1642,14 +1642,4 @@ class MessageControllerTest { return builder.build(); } - - @Test - void decodedSize() { - for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) { - final byte[] bytes = TestRandomUtil.nextBytes(size); - final String base64Encoded = Base64.getEncoder().encodeToString(bytes); - - assertEquals(bytes.length, MessageController.decodedSize(base64Encoded)); - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index 4544acd15..a3ed44661 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -69,7 +69,7 @@ class OutgoingMessageEntityTest { final Account account = new Account(); account.setUuid(UUID.randomUUID()); - IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA"); + IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4)); MessageProtos.Envelope baseEnvelope = message.toEnvelope( new AciServiceIdentifier(UUID.randomUUID()), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 14787d02e..b3990733c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -15,8 +15,8 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -150,7 +150,7 @@ public class ChangeNumberManagerTest { final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); @@ -203,7 +203,7 @@ public class ChangeNumberManagerTest { final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -254,7 +254,7 @@ public class ChangeNumberManagerTest { final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -301,7 +301,7 @@ public class ChangeNumberManagerTest { final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); @@ -350,7 +350,7 @@ public class ChangeNumberManagerTest { final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -393,8 +393,8 @@ public class ChangeNumberManagerTest { final byte destinationDeviceId2 = 2; final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, destinationDeviceId2, 1, "foo"), - new IncomingMessage(1, destinationDeviceId3, 1, "foo")); + new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)), + new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8))); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); @@ -431,8 +431,8 @@ public class ChangeNumberManagerTest { final byte destinationDeviceId2 = 2; final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, destinationDeviceId2, 1, "foo"), - new IncomingMessage(1, destinationDeviceId3, 1, "foo")); + new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)), + new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8))); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); @@ -469,8 +469,8 @@ public class ChangeNumberManagerTest { final byte destinationDeviceId2 = 2; final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, destinationDeviceId2, 2, "foo"), - new IncomingMessage(1, destinationDeviceId3, 3, "foo")); + new IncomingMessage(1, destinationDeviceId2, 2, "foo".getBytes(StandardCharsets.UTF_8)), + new IncomingMessage(1, destinationDeviceId3, 3, "foo".getBytes(StandardCharsets.UTF_8))); final Map registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89);