diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MonitoredS3ObjectConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MonitoredS3ObjectConfiguration.java
new file mode 100644
index 000000000..7ffe2dfe2
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/MonitoredS3ObjectConfiguration.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2013-2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.annotations.VisibleForTesting;
+import java.time.Duration;
+import javax.validation.constraints.NotBlank;
+
+/**
+ * Location and polling settings for an S3 object that is monitored for changes.
+ *
+ * @param s3Region the AWS region used to reach the bucket
+ * @param s3Bucket the bucket that holds the monitored object
+ * @param objectKey the key of the monitored object within the bucket
+ * @param maxSize the largest accepted object size, in bytes; a non-positive value is replaced by the default
+ * @param refreshInterval the time between polls; {@code null} is replaced by the default
+ */
+public record MonitoredS3ObjectConfiguration(
+    @NotBlank String s3Region,
+    @NotBlank String s3Bucket,
+    @NotBlank String objectKey,
+    long maxSize,
+    Duration refreshInterval
+) {
+
+  static final long DEFAULT_MAXSIZE = 16 * 1024 * 1024;
+  static final Duration DEFAULT_DURATION = Duration.ofMinutes(5);
+
+  /**
+   * Fills in defaults for values that were left unset, so that an omitted size limit (zero) or interval (null) cannot
+   * reach the monitor and reject every non-empty object or fail when the refresh is scheduled.
+   */
+  public MonitoredS3ObjectConfiguration {
+    if (maxSize <= 0) {
+      maxSize = DEFAULT_MAXSIZE;
+    }
+    if (refreshInterval == null) {
+      refreshInterval = DEFAULT_DURATION;
+    }
+  }
+
+  /**
+   * Creates a configuration that uses the default maximum size and refresh interval.
+   */
+  public MonitoredS3ObjectConfiguration(String s3Region, String s3Bucket, String objectKey) {
+    this(s3Region, s3Bucket, objectKey, DEFAULT_MAXSIZE, DEFAULT_DURATION);
+  }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/s3/ManagedSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/s3/ManagedSupplier.java
new file mode 100644
index 000000000..27920f26b
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/s3/ManagedSupplier.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2013-2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.s3;
+
+import io.dropwizard.lifecycle.Managed;
+import java.util.function.Supplier;
+
+/**
+ * A {@link Supplier} that also takes part in the {@link Managed} lifecycle. Both lifecycle hooks default to no-ops so
+ * that implementations only override the ones they need.
+ *
+ * @param <T> the type of value supplied
+ */
+public interface ManagedSupplier<T> extends Supplier<T>, Managed {
+
+  @Override
+  default void start() throws Exception {
+    // noop
+  }
+
+  @Override
+  default void stop() throws Exception {
+    // noop
+  }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3MonitoringSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3MonitoringSupplier.java
new file mode 100644
index 000000000..1111c80da
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3MonitoringSupplier.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2013-2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.s3;
+
+import static com.codahale.metrics.MetricRegistry.name;
+import static java.util.Objects.requireNonNull;
+
+import io.micrometer.core.instrument.Counter;
+import io.micrometer.core.instrument.Metrics;
+import io.micrometer.core.instrument.Timer;
+import java.io.InputStream;
+import java.lang.invoke.MethodHandles;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import javax.annotation.Nonnull;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
+
+/**
+ * Supplies a value parsed from a monitored S3 object. The value starts out as the {@code initial} argument and is
+ * replaced each time the underlying {@link S3ObjectMonitor} reports a change and the parser succeeds.
+ *
+ * @param <T> the type of the parsed value
+ */
+public class S3MonitoringSupplier<T> implements ManagedSupplier<T> {
+
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  @Nonnull
+  private final Timer refreshTimer;
+
+  @Nonnull
+  private final Counter refreshErrors;
+
+  @Nonnull
+  private final AtomicReference<T> holder;
+
+  @Nonnull
+  private final S3ObjectMonitor monitor;
+
+  @Nonnull
+  private final Function<InputStream, T> parser;
+
+  /**
+   * Creates a supplier backed by a new {@link S3ObjectMonitor}; nothing is fetched until {@link #start()}.
+   *
+   * @param executor runs the periodic refreshes
+   * @param cfg location, size limit and refresh interval of the monitored object
+   * @param parser converts the object's content into a value; expected to close the stream it is given
+   * @param initial the value returned by {@link #get()} until the first successful parse
+   * @param name prefix for the names of the {@code refresh} timer and {@code refreshErrors} counter
+   */
+  public S3MonitoringSupplier(
+      @Nonnull final ScheduledExecutorService executor,
+      @Nonnull final MonitoredS3ObjectConfiguration cfg,
+      @Nonnull final Function<InputStream, T> parser,
+      @Nonnull final T initial,
+      @Nonnull final String name) {
+    this.refreshTimer = Metrics.timer(name(name, "refresh"));
+    this.refreshErrors = Metrics.counter(name(name, "refreshErrors"));
+    this.holder = new AtomicReference<>(initial);
+    this.parser = requireNonNull(parser);
+    this.monitor = new S3ObjectMonitor(
+        cfg.s3Region(),
+        cfg.s3Bucket(),
+        cfg.objectKey(),
+        cfg.maxSize(),
+        executor,
+        cfg.refreshInterval(),
+        this::handleObjectChange
+    );
+  }
+
+  @Override
+  @Nonnull
+  public T get() {
+    return requireNonNull(holder.get());
+  }
+
+  @Override
+  public void start() throws Exception {
+    monitor.start();
+  }
+
+  @Override
+  public void stop() throws Exception {
+    monitor.stop();
+  }
+
+  private void handleObjectChange(@Nonnull final InputStream inputStream) {
+    refreshTimer.record(() -> {
+      // parser function is supposed to close the input stream
+      try {
+        // Reject a null parse result here so that it is counted as a refresh error and the previous value is kept,
+        // instead of storing null and making every later get() throw.
+        holder.set(requireNonNull(parser.apply(inputStream)));
+      } catch (final Exception e) {
+        log.error("failed to update internal state from the monitored object", e);
+        refreshErrors.increment();
+      }
+    });
+  }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitor.java b/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitor.java
new file mode 100644
index 000000000..05033e103
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitor.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2013-2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.s3;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.io.IOException;
+import java.io.InputStream;
+import java.time.Duration;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.WhisperServerService;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
+import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
+
+/**
+ * An S3 object monitor watches a specific object in an S3 bucket and notifies a listener if that object changes.
+ */
+public class S3ObjectMonitor {
+
+  private final String s3Bucket;
+  private final String objectKey;
+  private final long maxObjectSize;
+
+  private final ScheduledExecutorService refreshExecutorService;
+  private final Duration refreshInterval;
+  private ScheduledFuture<?> refreshFuture;
+
+  private final Consumer<InputStream> changeListener;
+
+  private final AtomicReference<String> lastETag = new AtomicReference<>();
+
+  private final S3Client s3Client;
+
+  private static final Logger log = LoggerFactory.getLogger(S3ObjectMonitor.class);
+
+  public S3ObjectMonitor(
+      final String s3Region,
+      final String s3Bucket,
+      final String objectKey,
+      final long maxObjectSize,
+      final ScheduledExecutorService refreshExecutorService,
+      final Duration refreshInterval,
+      final Consumer<InputStream> changeListener) {
+
+    this(S3Client.builder()
+            .region(Region.of(s3Region))
+            .credentialsProvider(WhisperServerService.AWSSDK_CREDENTIALS_PROVIDER)
+            .build(),
+        s3Bucket,
+        objectKey,
+        maxObjectSize,
+        refreshExecutorService,
+        refreshInterval,
+        changeListener);
+  }
+
+  @VisibleForTesting
+  S3ObjectMonitor(
+      final S3Client s3Client,
+      final String s3Bucket,
+      final String objectKey,
+      final long maxObjectSize,
+      final ScheduledExecutorService refreshExecutorService,
+      final Duration refreshInterval,
+      final Consumer<InputStream> changeListener) {
+
+    this.s3Client = s3Client;
+    this.s3Bucket = s3Bucket;
+    this.objectKey = objectKey;
+    this.maxObjectSize = maxObjectSize;
+
+    this.refreshExecutorService = refreshExecutorService;
+    this.refreshInterval = refreshInterval;
+
+    this.changeListener = changeListener;
+  }
+
+  /**
+   * Performs one blocking refresh, then schedules further refreshes at the configured interval.
+   *
+   * @throws IllegalStateException if this monitor has already been started
+   */
+  public synchronized void start() {
+    if (refreshFuture != null) {
+      throw new IllegalStateException("S3 object manager already started");
+    }
+
+    // Run the first request immediately/blocking, then start subsequent calls.
+    log.info("Initial request for s3://{}/{}", s3Bucket, objectKey);
+    refresh();
+
+    refreshFuture = refreshExecutorService
+        .scheduleAtFixedRate(this::refresh, refreshInterval.toMillis(), refreshInterval.toMillis(),
+            TimeUnit.MILLISECONDS);
+  }
+
+  /**
+   * Cancels the scheduled refreshes, allowing an in-progress refresh to be interrupted. Does nothing if never started.
+   */
+  public synchronized void stop() {
+    if (refreshFuture != null) {
+      refreshFuture.cancel(true);
+    }
+  }
+
+  /**
+   * Immediately returns the monitored S3 object regardless of whether it has changed since it was last retrieved.
+   *
+   * @return the current version of the monitored S3 object. Caller should close() this upon completion.
+   * @throws IOException if the retrieved S3 object is larger than the configured maximum size
+   */
+  @VisibleForTesting
+  ResponseInputStream<GetObjectResponse> getObject() throws IOException {
+    final ResponseInputStream<GetObjectResponse> response = s3Client.getObject(GetObjectRequest.builder()
+        .key(objectKey)
+        .bucket(s3Bucket)
+        .build());
+
+    lastETag.set(response.response().eTag());
+
+    if (response.response().contentLength() <= maxObjectSize) {
+      return response;
+    } else {
+      log.warn("Object at s3://{}/{} has a size of {} bytes, which exceeds the maximum allowed size of {} bytes",
+          s3Bucket, objectKey, response.response().contentLength(), maxObjectSize);
+      response.abort();
+      throw new IOException("S3 object too large");
+    }
+  }
+
+  /**
+   * Polls S3 for object metadata and notifies the listener provided at construction time if and only if the object has
+   * changed since the last call to {@link #getObject()} or {@code refresh()}.
+   */
+  @VisibleForTesting
+  void refresh() {
+    try {
+      final HeadObjectResponse objectMetadata = s3Client.headObject(HeadObjectRequest.builder()
+          .bucket(s3Bucket)
+          .key(objectKey)
+          .build());
+
+      final String initialETag = lastETag.get();
+      final String refreshedETag = objectMetadata.eTag();
+
+      // The new eTag is recorded before the object is fetched, so a failed fetch or a throwing listener is not retried
+      // until the eTag changes again.
+      if (!StringUtils.equals(initialETag, refreshedETag) && lastETag.compareAndSet(initialETag, refreshedETag)) {
+        try (final ResponseInputStream<GetObjectResponse> response = getObject()) {
+          log.info("Object at s3://{}/{} has changed; new eTag is {} and object size is {} bytes",
+              s3Bucket, objectKey, response.response().eTag(), response.response().contentLength());
+          changeListener.accept(response);
+        }
+      }
+    } catch (final Exception e) {
+      log.warn("Failed to refresh monitored object", e);
+    }
+  }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitorTest.java
new file mode 100644
index 000000000..cea5b33c2
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/s3/S3ObjectMonitorTest.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2013-2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.s3;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.util.UUID;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.function.Consumer;
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.http.AbortableInputStream;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
+import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
+
+class S3ObjectMonitorTest {
+
+  @Test
+  void refresh() {
+    final S3Client s3Client = mock(S3Client.class);
+
+    final String bucket = "s3bucket";
+    final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
+
+    //noinspection unchecked
+    final Consumer<InputStream> listener = mock(Consumer.class);
+
+    final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
+        s3Client,
+        bucket,
+        objectKey,
+        16 * 1024 * 1024,
+        mock(ScheduledExecutorService.class),
+        Duration.ofMinutes(1),
+        listener);
+
+    final String uuid = UUID.randomUUID().toString();
+    when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(
+        HeadObjectResponse.builder().eTag(uuid).build());
+    final ResponseInputStream<GetObjectResponse> ris = responseInputStreamFromString("abc", uuid);
+    when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris);
+
+    objectMonitor.refresh();
+    objectMonitor.refresh();
+
+    verify(listener).accept(ris);
+  }
+
+  @Test
+  void refreshAfterGet() throws IOException {
+    final S3Client s3Client = mock(S3Client.class);
+
+    final String bucket = "s3bucket";
+    final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
+
+    //noinspection unchecked
+    final Consumer<InputStream> listener = mock(Consumer.class);
+
+    final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
+        s3Client,
+        bucket,
+        objectKey,
+        16 * 1024 * 1024,
+        mock(ScheduledExecutorService.class),
+        Duration.ofMinutes(1),
+        listener);
+
+    final String uuid = UUID.randomUUID().toString();
+    when(s3Client.headObject(HeadObjectRequest.builder().key(objectKey).bucket(bucket).build()))
+        .thenReturn(HeadObjectResponse.builder().eTag(uuid).build());
+    final ResponseInputStream<GetObjectResponse> responseInputStream = responseInputStreamFromString("abc", uuid);
+    when(s3Client.getObject(GetObjectRequest.builder().key(objectKey).bucket(bucket).build())).thenReturn(responseInputStream);
+
+    objectMonitor.getObject();
+    objectMonitor.refresh();
+
+    verify(listener, never()).accept(responseInputStream);
+  }
+
+  private ResponseInputStream<GetObjectResponse> responseInputStreamFromString(final String s, final String etag) {
+    final byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
+    final AbortableInputStream ais = AbortableInputStream.create(new ByteArrayInputStream(bytes));
+    return new ResponseInputStream<>(GetObjectResponse.builder().contentLength((long) bytes.length).eTag(etag).build(), ais);
+  }
+
+  @Test
+  void refreshOversizedObject() {
+    final S3Client s3Client = mock(S3Client.class);
+
+    final String bucket = "s3bucket";
+    final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
+    final long maxObjectSize = 16 * 1024 * 1024;
+
+    //noinspection unchecked
+    final Consumer<InputStream> listener = mock(Consumer.class);
+
+    final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
+        s3Client,
+        bucket,
+        objectKey,
+        maxObjectSize,
+        mock(ScheduledExecutorService.class),
+        Duration.ofMinutes(1),
+        listener);
+
+    final String uuid = UUID.randomUUID().toString();
+    when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(
+        HeadObjectResponse.builder().eTag(uuid).contentLength(maxObjectSize+1).build());
+    final ResponseInputStream<GetObjectResponse> ris = responseInputStreamFromString("a".repeat((int) maxObjectSize+1), uuid);
+    when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris);
+
+    objectMonitor.refresh();
+
+    verify(listener, never()).accept(any());
+  }
+}