Parse message content as a byte array in request entities

This commit is contained in:
Jon Chambers 2025-02-10 14:53:31 -05:00 committed by Jon Chambers
parent db2cd20dcb
commit faef614d80
7 changed files with 45 additions and 59 deletions

View File

@ -8,7 +8,6 @@ package org.signal.integration;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Test;
@ -27,11 +26,10 @@ public class MessagingTest {
try {
final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8);
final String contentBase64 = Base64.getEncoder().encodeToString(expectedContent);
final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), contentBase64);
final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), expectedContent);
final IncomingMessageList messages = new IncomingMessageList(List.of(message), false, true, System.currentTimeMillis());
final Pair<Integer, SendMessageResponse> sendMessage = Operations
Operations
.apiPut("/v1/messages/%s".formatted(userB.aciUuid().toString()), messages)
.authorized(userA)
.execute(SendMessageResponse.class);
@ -40,7 +38,7 @@ public class MessagingTest {
.authorized(userB)
.execute(OutgoingMessageEntityList.class);
final byte[] actualContent = receiveMessages.getRight().messages().get(0).content();
final byte[] actualContent = receiveMessages.getRight().messages().getFirst().content();
assertArrayEquals(expectedContent, actualContent);
} finally {
Operations.deleteUser(userA);

View File

@ -62,7 +62,6 @@ import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
@ -324,7 +323,7 @@ public class MessageController {
int totalContentLength = 0;
for (final IncomingMessage message : messages.messages()) {
final int contentLength = decodedSize(message.content());
final int contentLength = message.content() != null ? message.content().length : 0;
validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent);
@ -955,19 +954,4 @@ public class MessageController {
.increment();
}
}
@VisibleForTesting
static int decodedSize(final String base64) {
final int padding;
if (StringUtils.endsWith(base64, "==")) {
padding = 2;
} else if (StringUtils.endsWith(base64, "=")) {
padding = 1;
} else {
padding = 0;
}
return ((StringUtils.length(base64) - padding) * 3) / 4;
}
}

View File

@ -4,18 +4,25 @@
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.protobuf.ByteString;
import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deserializer;
import io.micrometer.core.instrument.Metrics;
import jakarta.validation.constraints.AssertTrue;
import java.util.Base64;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import java.util.Arrays;
import java.util.Objects;
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
public record IncomingMessage(int type,
byte destinationDeviceId,
int destinationRegistrationId,
@JsonDeserialize(using = ByteArrayBase64Deserializer.class)
byte[] content) {
private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME =
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
@ -50,8 +57,8 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken));
}
if (StringUtils.isNotEmpty(content())) {
envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content())));
if (content() != null && content().length > 0) {
envelopeBuilder.setContent(ByteString.copyFrom(content()));
}
return envelopeBuilder.build();
@ -69,4 +76,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
return true;
}
@Override
public boolean equals(final Object o) {
if (!(o instanceof IncomingMessage(int otherType, byte otherDeviceId, int otherRegistrationId, byte[] otherContent)))
return false;
return type == otherType && destinationDeviceId == otherDeviceId
&& destinationRegistrationId == otherRegistrationId && Objects.deepEquals(content, otherContent);
}
@Override
public int hashCode() {
return Objects.hash(type, destinationDeviceId, destinationRegistrationId, Arrays.hashCode(content));
}
}

View File

@ -13,7 +13,6 @@ import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -138,16 +137,11 @@ public class ChangeNumberManager {
}
private static Optional<byte[]> getMessageContent(final IncomingMessage message) {
if (StringUtils.isEmpty(message.content())) {
if (message.content() == null || message.content().length == 0) {
logger.warn("Message has no content");
return Optional.empty();
}
try {
return Optional.of(Base64.getDecoder().decode(message.content()));
} catch (final IllegalArgumentException e) {
logger.warn("Failed to parse message content", e);
return Optional.empty();
}
return Optional.of(message.content());
}
}

View File

@ -1096,7 +1096,7 @@ class MessageControllerTest {
.request()
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, (byte) 1, 1, Base64.getEncoder().encodeToString(contentBytes))), false, true,
List.of(new IncomingMessage(1, (byte) 1, 1, contentBytes)), false, true,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE))) {
@ -1642,14 +1642,4 @@ class MessageControllerTest {
return builder.build();
}
@Test
void decodedSize() {
for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) {
final byte[] bytes = TestRandomUtil.nextBytes(size);
final String base64Encoded = Base64.getEncoder().encodeToString(bytes);
assertEquals(bytes.length, MessageController.decodedSize(base64Encoded));
}
}
}

View File

@ -69,7 +69,7 @@ class OutgoingMessageEntityTest {
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA");
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),

View File

@ -15,8 +15,8 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -150,7 +150,7 @@ public class ChangeNumberManagerTest {
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@ -203,7 +203,7 @@ public class ChangeNumberManagerTest {
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -254,7 +254,7 @@ public class ChangeNumberManagerTest {
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -301,7 +301,7 @@ public class ChangeNumberManagerTest {
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@ -350,7 +350,7 @@ public class ChangeNumberManagerTest {
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -393,8 +393,8 @@ public class ChangeNumberManagerTest {
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
@ -431,8 +431,8 @@ public class ChangeNumberManagerTest {
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
@ -469,8 +469,8 @@ public class ChangeNumberManagerTest {
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 2, "foo"),
new IncomingMessage(1, destinationDeviceId3, 3, "foo"));
new IncomingMessage(1, destinationDeviceId2, 2, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 3, "foo".getBytes(StandardCharsets.UTF_8)));
final Map<Byte, Integer> registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89);