diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/AsnManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/AsnManager.java index c373ba988..a87ae0a10 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/AsnManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/AsnManager.java @@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.util; import static com.codahale.metrics.MetricRegistry.name; -import com.amazonaws.services.s3.model.S3Object; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.lifecycle.Managed; import io.micrometer.core.instrument.Counter; @@ -25,6 +24,8 @@ import java.util.zip.GZIPInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; public class AsnManager implements Managed { @@ -53,12 +54,6 @@ public class AsnManager implements Managed { @Override public void start() throws Exception { - try { - handleAsnTableChanged(asnTableMonitor.getObject()); - } catch (final Exception e) { - log.warn("Failed to load initial IP-to-ASN map", e); - } - asnTableMonitor.start(); } @@ -76,10 +71,10 @@ public class AsnManager implements Managed { } } - private void handleAsnTableChanged(final S3Object asnTableObject) { + private void handleAsnTableChanged(final ResponseInputStream asnTableObject) { REFRESH_TIMER.record(() -> { try { - handleAsnTableChanged(new GZIPInputStream(asnTableObject.getObjectContent())); + handleAsnTableChangedStream(new GZIPInputStream(asnTableObject)); } catch (final IOException e) { log.error("Retrieved object was not a gzip archive", e); } @@ -87,7 +82,7 @@ public class AsnManager implements Managed { } @VisibleForTesting - void handleAsnTableChanged(final InputStream inputStream) { + void handleAsnTableChangedStream(final InputStream inputStream) { try (final InputStreamReader reader = new InputStreamReader(inputStream)) { asnTable.set(new AsnTable(reader)); } catch (final Exception e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitor.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitor.java index 61b5896f8..a7d0d7da3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitor.java @@ -5,11 +5,6 @@ package org.whispersystems.textsecuregcm.util; -import com.amazonaws.auth.InstanceProfileCredentialsProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.S3Object; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.lifecycle.Managed; import java.io.IOException; @@ -22,6 +17,14 @@ import java.util.function.Consumer; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; /** * An S3 object monitor watches a specific object in an S3 bucket and notifies a listener if that object changes. @@ -36,11 +39,11 @@ public class S3ObjectMonitor implements Managed { private final Duration refreshInterval; private ScheduledFuture refreshFuture; - private final Consumer changeListener; + private final Consumer> changeListener; private final AtomicReference lastETag = new AtomicReference<>(); - private final AmazonS3 s3Client; + private final S3Client s3Client; private static final Logger log = LoggerFactory.getLogger(S3ObjectMonitor.class); @@ -51,11 +54,11 @@ public class S3ObjectMonitor implements Managed { final long maxObjectSize, final ScheduledExecutorService refreshExecutorService, final Duration refreshInterval, - final Consumer changeListener) { + final Consumer> changeListener) { - this(AmazonS3ClientBuilder.standard() - .withCredentials(InstanceProfileCredentialsProvider.getInstance()) - .withRegion(s3Region) + this(S3Client.builder() + .region(Region.of(s3Region)) + .credentialsProvider(InstanceProfileCredentialsProvider.create()) .build(), s3Bucket, objectKey, @@ -67,13 +70,13 @@ public class S3ObjectMonitor implements Managed { @VisibleForTesting S3ObjectMonitor( - final AmazonS3 s3Client, + final S3Client s3Client, final String s3Bucket, final String objectKey, final long maxObjectSize, final ScheduledExecutorService refreshExecutorService, final Duration refreshInterval, - final Consumer changeListener) { + final Consumer> changeListener) { this.s3Client = s3Client; this.s3Bucket = s3Bucket; @@ -92,8 +95,13 @@ public class S3ObjectMonitor implements Managed { throw new RuntimeException("S3 object manager already started"); } + // Run the first request immediately/blocking, then start subsequent calls. + log.info("Initial request for s3://{}/{}", s3Bucket, objectKey); + refresh(); + refreshFuture = refreshExecutorService - .scheduleAtFixedRate(this::refresh, refreshInterval.toMillis(), refreshInterval.toMillis(), TimeUnit.MILLISECONDS); + .scheduleAtFixedRate(this::refresh, refreshInterval.toMillis(), refreshInterval.toMillis(), + TimeUnit.MILLISECONDS); } @Override @@ -106,21 +114,24 @@ public class S3ObjectMonitor implements Managed { /** * Immediately returns the monitored S3 object regardless of whether it has changed since it was last retrieved. * - * @return the current version of the monitored S3 object - * + * @return the current version of the monitored S3 object. Caller should close() this upon completion. * @throws IOException if the retrieved S3 object is larger than the configured maximum size */ - public S3Object getObject() throws IOException { - final S3Object s3Object = s3Client.getObject(s3Bucket, objectKey); + @VisibleForTesting + ResponseInputStream getObject() throws IOException { + ResponseInputStream response = s3Client.getObject(GetObjectRequest.builder() + .key(objectKey) + .bucket(s3Bucket) + .build()); - lastETag.set(s3Object.getObjectMetadata().getETag()); + lastETag.set(response.response().eTag()); - if (s3Object.getObjectMetadata().getContentLength() <= maxObjectSize) { - return s3Object; + if (response.response().contentLength() <= maxObjectSize) { + return response; } else { log.warn("Object at s3://{}/{} has a size of {} bytes, which exceeds the maximum allowed size of {} bytes", - s3Bucket, objectKey, s3Object.getObjectMetadata().getContentLength(), maxObjectSize); - + s3Bucket, objectKey, response.response().contentLength(), maxObjectSize); + response.abort(); throw new IOException("S3 object too large"); } } @@ -132,18 +143,25 @@ public class S3ObjectMonitor implements Managed { @VisibleForTesting void refresh() { try { - final ObjectMetadata objectMetadata = s3Client.getObjectMetadata(s3Bucket, objectKey); + final HeadObjectResponse objectMetadata = s3Client.headObject(HeadObjectRequest.builder() + .bucket(s3Bucket) + .key(objectKey) + .build()); final String initialETag = lastETag.get(); - final String refreshedETag = objectMetadata.getETag(); + final String refreshedETag = objectMetadata.eTag(); if (!StringUtils.equals(initialETag, refreshedETag) && lastETag.compareAndSet(initialETag, refreshedETag)) { - final S3Object s3Object = getObject(); + final ResponseInputStream response = getObject(); log.info("Object at s3://{}/{} has changed; new eTag is {} and object size is {} bytes", - s3Bucket, objectKey, s3Object.getObjectMetadata().getETag(), s3Object.getObjectMetadata().getContentLength()); + s3Bucket, objectKey, response.response().eTag(), response.response().contentLength()); - changeListener.accept(s3Object); + try { + changeListener.accept(response); + } finally { + response.close(); + } } } catch (final Exception e) { log.warn("Failed to refresh monitored object", e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/TorExitNodeManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/TorExitNodeManager.java index 00e758972..a4dfbaf94 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/TorExitNodeManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/TorExitNodeManager.java @@ -7,14 +7,12 @@ package org.whispersystems.textsecuregcm.util; import static com.codahale.metrics.MetricRegistry.name; -import com.amazonaws.services.s3.model.S3Object; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.lifecycle.Managed; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.io.BufferedReader; -import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.Collections; @@ -25,6 +23,8 @@ import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; /** * A utility for checking whether IP addresses belong to Tor exit nodes using the "bulk exit list." @@ -58,12 +58,6 @@ public class TorExitNodeManager implements Managed { @Override public synchronized void start() { - try { - handleExitListChanged(exitListMonitor.getObject()); - } catch (final Exception e) { - log.warn("Failed to load initial Tor exit node list", e); - } - exitListMonitor.start(); } @@ -76,12 +70,12 @@ public class TorExitNodeManager implements Managed { return exitNodeAddresses.get().contains(address); } - private void handleExitListChanged(final S3Object exitList) { - REFRESH_TIMER.record(() -> handleExitListChanged(exitList.getObjectContent())); + private void handleExitListChanged(final ResponseInputStream exitList) { + REFRESH_TIMER.record(() -> handleExitListChanged(exitList)); } @VisibleForTesting - void handleExitListChanged(final InputStream inputStream) { + void handleExitListChangedStream(final InputStream inputStream) { try (final BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { exitNodeAddresses.set(reader.lines().collect(Collectors.toSet())); } catch (final Exception e) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/AsnManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/AsnManagerTest.java index ad756362e..876d8ddf2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/AsnManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/AsnManagerTest.java @@ -8,17 +8,11 @@ package org.whispersystems.textsecuregcm.util; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.io.InputStreamReader; -import java.net.Inet4Address; -import java.nio.charset.StandardCharsets; import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.mock; @@ -34,7 +28,7 @@ class AsnManagerTest { assertEquals(Optional.empty(), asnManager.getAsn("10.0.0.1")); try (final InputStream tableInputStream = getClass().getResourceAsStream("ip2asn-test.tsv")) { - asnManager.handleAsnTableChanged(tableInputStream); + asnManager.handleAsnTableChangedStream(tableInputStream); } assertEquals(Optional.of(7922L), asnManager.getAsn("50.79.54.1")); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitorTest.java index acb36d1f3..8e373c99a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/S3ObjectMonitorTest.java @@ -1,11 +1,17 @@ package org.whispersystems.textsecuregcm.util; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.S3Object; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.UUID; import java.util.concurrent.ScheduledExecutorService; @@ -21,15 +27,13 @@ class S3ObjectMonitorTest { @Test void refresh() { - final AmazonS3 s3Client = mock(AmazonS3.class); - final ObjectMetadata metadata = mock(ObjectMetadata.class); - final S3Object s3Object = mock(S3Object.class); + final S3Client s3Client = mock(S3Client.class); final String bucket = "s3bucket"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; //noinspection unchecked - final Consumer listener = mock(Consumer.class); + final Consumer> listener = mock(Consumer.class); final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( s3Client, @@ -40,28 +44,27 @@ class S3ObjectMonitorTest { Duration.ofMinutes(1), listener); - when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); - when(s3Object.getObjectMetadata()).thenReturn(metadata); - when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); - when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); + String uuid = UUID.randomUUID().toString(); + when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn( + HeadObjectResponse.builder().eTag(uuid).build()); + ResponseInputStream ris = responseInputStreamFromString("abc", uuid); + when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris); objectMonitor.refresh(); objectMonitor.refresh(); - verify(listener).accept(s3Object); + verify(listener).accept(ris); } @Test void refreshAfterGet() throws IOException { - final AmazonS3 s3Client = mock(AmazonS3.class); - final ObjectMetadata metadata = mock(ObjectMetadata.class); - final S3Object s3Object = mock(S3Object.class); + final S3Client s3Client = mock(S3Client.class); final String bucket = "s3bucket"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; //noinspection unchecked - final Consumer listener = mock(Consumer.class); + final Consumer> listener = mock(Consumer.class); final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( s3Client, @@ -72,29 +75,34 @@ class S3ObjectMonitorTest { Duration.ofMinutes(1), listener); - when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); - when(s3Object.getObjectMetadata()).thenReturn(metadata); - when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); - when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); + String uuid = UUID.randomUUID().toString(); + when(s3Client.headObject(HeadObjectRequest.builder().key(objectKey).bucket(bucket).build())) + .thenReturn(HeadObjectResponse.builder().eTag(uuid).build()); + ResponseInputStream responseInputStream = responseInputStreamFromString("abc", uuid); + when(s3Client.getObject(GetObjectRequest.builder().key(objectKey).bucket(bucket).build())).thenReturn(responseInputStream); objectMonitor.getObject(); objectMonitor.refresh(); - verify(listener, never()).accept(s3Object); + verify(listener, never()).accept(responseInputStream); + } + + private ResponseInputStream responseInputStreamFromString(String s, String etag) { + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); + AbortableInputStream ais = AbortableInputStream.create(new ByteArrayInputStream(bytes)); + return new ResponseInputStream<>(GetObjectResponse.builder().contentLength((long) bytes.length).eTag(etag).build(), ais); } @Test void refreshOversizedObject() { - final AmazonS3 s3Client = mock(AmazonS3.class); - final ObjectMetadata metadata = mock(ObjectMetadata.class); - final S3Object s3Object = mock(S3Object.class); + final S3Client s3Client = mock(S3Client.class); final String bucket = "s3bucket"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; final long maxObjectSize = 16 * 1024 * 1024; //noinspection unchecked - final Consumer listener = mock(Consumer.class); + final Consumer> listener = mock(Consumer.class); final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( s3Client, @@ -105,11 +113,11 @@ class S3ObjectMonitorTest { Duration.ofMinutes(1), listener); - when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); - when(metadata.getContentLength()).thenReturn(maxObjectSize + 1); - when(s3Object.getObjectMetadata()).thenReturn(metadata); - when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); - when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); + String uuid = UUID.randomUUID().toString(); + when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn( + HeadObjectResponse.builder().eTag(uuid).contentLength(maxObjectSize+1).build()); + ResponseInputStream ris = responseInputStreamFromString("a".repeat((int) maxObjectSize+1), uuid); + when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris); objectMonitor.refresh(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java index 76f25d0e6..55554459d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java @@ -29,7 +29,7 @@ public class TorExitNodeManagerTest extends AbstractRedisClusterTest { assertFalse(torExitNodeManager.isTorExitNode("10.0.0.1")); assertFalse(torExitNodeManager.isTorExitNode("10.0.0.2")); - torExitNodeManager.handleExitListChanged( + torExitNodeManager.handleExitListChangedStream( new ByteArrayInputStream("10.0.0.1\n10.0.0.2".getBytes(StandardCharsets.UTF_8))); assertTrue(torExitNodeManager.isTorExitNode("10.0.0.1"));