Parse message content as a byte array in request entities
This commit is contained in:
parent
db2cd20dcb
commit
faef614d80
|
@ -8,7 +8,6 @@ package org.signal.integration;
|
|||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -27,11 +26,10 @@ public class MessagingTest {
|
|||
|
||||
try {
|
||||
final byte[] expectedContent = "Hello, World!".getBytes(StandardCharsets.UTF_8);
|
||||
final String contentBase64 = Base64.getEncoder().encodeToString(expectedContent);
|
||||
final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), contentBase64);
|
||||
final IncomingMessage message = new IncomingMessage(1, Device.PRIMARY_ID, userB.registrationId(), expectedContent);
|
||||
final IncomingMessageList messages = new IncomingMessageList(List.of(message), false, true, System.currentTimeMillis());
|
||||
|
||||
final Pair<Integer, SendMessageResponse> sendMessage = Operations
|
||||
Operations
|
||||
.apiPut("/v1/messages/%s".formatted(userB.aciUuid().toString()), messages)
|
||||
.authorized(userA)
|
||||
.execute(SendMessageResponse.class);
|
||||
|
@ -40,7 +38,7 @@ public class MessagingTest {
|
|||
.authorized(userB)
|
||||
.execute(OutgoingMessageEntityList.class);
|
||||
|
||||
final byte[] actualContent = receiveMessages.getRight().messages().get(0).content();
|
||||
final byte[] actualContent = receiveMessages.getRight().messages().getFirst().content();
|
||||
assertArrayEquals(expectedContent, actualContent);
|
||||
} finally {
|
||||
Operations.deleteUser(userA);
|
||||
|
|
|
@ -62,7 +62,6 @@ import java.util.concurrent.ExecutionException;
|
|||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.ServiceId;
|
||||
|
@ -324,7 +323,7 @@ public class MessageController {
|
|||
int totalContentLength = 0;
|
||||
|
||||
for (final IncomingMessage message : messages.messages()) {
|
||||
final int contentLength = decodedSize(message.content());
|
||||
final int contentLength = message.content() != null ? message.content().length : 0;
|
||||
|
||||
validateContentLength(contentLength, false, isSyncMessage, isStory, userAgent);
|
||||
|
||||
|
@ -955,19 +954,4 @@ public class MessageController {
|
|||
.increment();
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static int decodedSize(final String base64) {
|
||||
final int padding;
|
||||
|
||||
if (StringUtils.endsWith(base64, "==")) {
|
||||
padding = 2;
|
||||
} else if (StringUtils.endsWith(base64, "=")) {
|
||||
padding = 1;
|
||||
} else {
|
||||
padding = 0;
|
||||
}
|
||||
|
||||
return ((StringUtils.length(base64) - padding) * 3) / 4;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,18 +4,25 @@
|
|||
*/
|
||||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deserializer;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import jakarta.validation.constraints.AssertTrue;
|
||||
import java.util.Base64;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
|
||||
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
|
||||
public record IncomingMessage(int type,
|
||||
byte destinationDeviceId,
|
||||
int destinationRegistrationId,
|
||||
|
||||
@JsonDeserialize(using = ByteArrayBase64Deserializer.class)
|
||||
byte[] content) {
|
||||
|
||||
private static final String REJECT_INVALID_ENVELOPE_TYPE_COUNTER_NAME =
|
||||
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
|
||||
|
@ -50,8 +57,8 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
|
|||
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken));
|
||||
}
|
||||
|
||||
if (StringUtils.isNotEmpty(content())) {
|
||||
envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content())));
|
||||
if (content() != null && content().length > 0) {
|
||||
envelopeBuilder.setContent(ByteString.copyFrom(content()));
|
||||
}
|
||||
|
||||
return envelopeBuilder.build();
|
||||
|
@ -69,4 +76,17 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(final Object o) {
|
||||
if (!(o instanceof IncomingMessage(int otherType, byte otherDeviceId, int otherRegistrationId, byte[] otherContent)))
|
||||
return false;
|
||||
return type == otherType && destinationDeviceId == otherDeviceId
|
||||
&& destinationRegistrationId == otherRegistrationId && Objects.deepEquals(content, otherContent);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(type, destinationDeviceId, destinationRegistrationId, Arrays.hashCode(content));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ import java.util.Set;
|
|||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.ObjectUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -138,16 +137,11 @@ public class ChangeNumberManager {
|
|||
}
|
||||
|
||||
private static Optional<byte[]> getMessageContent(final IncomingMessage message) {
|
||||
if (StringUtils.isEmpty(message.content())) {
|
||||
if (message.content() == null || message.content().length == 0) {
|
||||
logger.warn("Message has no content");
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
try {
|
||||
return Optional.of(Base64.getDecoder().decode(message.content()));
|
||||
} catch (final IllegalArgumentException e) {
|
||||
logger.warn("Failed to parse message content", e);
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(message.content());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1096,7 +1096,7 @@ class MessageControllerTest {
|
|||
.request()
|
||||
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
|
||||
.put(Entity.entity(new IncomingMessageList(
|
||||
List.of(new IncomingMessage(1, (byte) 1, 1, Base64.getEncoder().encodeToString(contentBytes))), false, true,
|
||||
List.of(new IncomingMessage(1, (byte) 1, 1, contentBytes)), false, true,
|
||||
System.currentTimeMillis()),
|
||||
MediaType.APPLICATION_JSON_TYPE))) {
|
||||
|
||||
|
@ -1642,14 +1642,4 @@ class MessageControllerTest {
|
|||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
void decodedSize() {
|
||||
for (int size = MessageController.MAX_MESSAGE_SIZE - 3; size <= MessageController.MAX_MESSAGE_SIZE + 3; size++) {
|
||||
final byte[] bytes = TestRandomUtil.nextBytes(size);
|
||||
final String base64Encoded = Base64.getEncoder().encodeToString(bytes);
|
||||
|
||||
assertEquals(bytes.length, MessageController.decodedSize(base64Encoded));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ class OutgoingMessageEntityTest {
|
|||
final Account account = new Account();
|
||||
account.setUuid(UUID.randomUUID());
|
||||
|
||||
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA");
|
||||
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
|
||||
|
||||
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
|
||||
new AciServiceIdentifier(UUID.randomUUID()),
|
||||
|
|
|
@ -15,8 +15,8 @@ import static org.mockito.Mockito.never;
|
|||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -150,7 +150,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final IncomingMessage msg = mock(IncomingMessage.class);
|
||||
when(msg.destinationDeviceId()).thenReturn(deviceId2);
|
||||
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
|
||||
when(msg.content()).thenReturn(new byte[]{1});
|
||||
|
||||
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
|
||||
|
||||
|
@ -203,7 +203,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final IncomingMessage msg = mock(IncomingMessage.class);
|
||||
when(msg.destinationDeviceId()).thenReturn(deviceId2);
|
||||
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
|
||||
when(msg.content()).thenReturn(new byte[]{1});
|
||||
|
||||
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
|
||||
|
||||
|
@ -254,7 +254,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final IncomingMessage msg = mock(IncomingMessage.class);
|
||||
when(msg.destinationDeviceId()).thenReturn(deviceId2);
|
||||
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
|
||||
when(msg.content()).thenReturn(new byte[]{1});
|
||||
|
||||
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
|
||||
|
||||
|
@ -301,7 +301,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final IncomingMessage msg = mock(IncomingMessage.class);
|
||||
when(msg.destinationDeviceId()).thenReturn(deviceId2);
|
||||
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
|
||||
when(msg.content()).thenReturn(new byte[]{1});
|
||||
|
||||
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
|
||||
|
||||
|
@ -350,7 +350,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final IncomingMessage msg = mock(IncomingMessage.class);
|
||||
when(msg.destinationDeviceId()).thenReturn(deviceId2);
|
||||
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
|
||||
when(msg.content()).thenReturn(new byte[]{1});
|
||||
|
||||
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
|
||||
|
||||
|
@ -393,8 +393,8 @@ public class ChangeNumberManagerTest {
|
|||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
@ -431,8 +431,8 @@ public class ChangeNumberManagerTest {
|
|||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
@ -469,8 +469,8 @@ public class ChangeNumberManagerTest {
|
|||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 2, "foo"),
|
||||
new IncomingMessage(1, destinationDeviceId3, 3, "foo"));
|
||||
new IncomingMessage(1, destinationDeviceId2, 2, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 3, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final Map<Byte, Integer> registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89);
|
||||
|
||||
|
|
Loading…
Reference in New Issue