Switch S3ObjectMonitor to AWSv2 SDK.

This commit is contained in:
Graeme Connell 2021-05-20 16:09:12 -06:00 committed by gram-signal
parent 680e501f83
commit 722055c8b5
6 changed files with 96 additions and 87 deletions

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.util;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.services.s3.model.S3Object;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
@ -25,6 +24,8 @@ import java.util.zip.GZIPInputStream;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
public class AsnManager implements Managed { public class AsnManager implements Managed {
@ -53,12 +54,6 @@ public class AsnManager implements Managed {
@Override @Override
public void start() throws Exception { public void start() throws Exception {
try {
handleAsnTableChanged(asnTableMonitor.getObject());
} catch (final Exception e) {
log.warn("Failed to load initial IP-to-ASN map", e);
}
asnTableMonitor.start(); asnTableMonitor.start();
} }
@ -76,10 +71,10 @@ public class AsnManager implements Managed {
} }
} }
private void handleAsnTableChanged(final S3Object asnTableObject) { private void handleAsnTableChanged(final ResponseInputStream<GetObjectResponse> asnTableObject) {
REFRESH_TIMER.record(() -> { REFRESH_TIMER.record(() -> {
try { try {
handleAsnTableChanged(new GZIPInputStream(asnTableObject.getObjectContent())); handleAsnTableChangedStream(new GZIPInputStream(asnTableObject));
} catch (final IOException e) { } catch (final IOException e) {
log.error("Retrieved object was not a gzip archive", e); log.error("Retrieved object was not a gzip archive", e);
} }
@ -87,7 +82,7 @@ public class AsnManager implements Managed {
} }
@VisibleForTesting @VisibleForTesting
void handleAsnTableChanged(final InputStream inputStream) { void handleAsnTableChangedStream(final InputStream inputStream) {
try (final InputStreamReader reader = new InputStreamReader(inputStream)) { try (final InputStreamReader reader = new InputStreamReader(inputStream)) {
asnTable.set(new AsnTable(reader)); asnTable.set(new AsnTable(reader));
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -5,11 +5,6 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.S3Object;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import java.io.IOException; import java.io.IOException;
@ -22,6 +17,14 @@ import java.util.function.Consumer;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
/** /**
* An S3 object monitor watches a specific object in an S3 bucket and notifies a listener if that object changes. * An S3 object monitor watches a specific object in an S3 bucket and notifies a listener if that object changes.
@ -36,11 +39,11 @@ public class S3ObjectMonitor implements Managed {
private final Duration refreshInterval; private final Duration refreshInterval;
private ScheduledFuture<?> refreshFuture; private ScheduledFuture<?> refreshFuture;
private final Consumer<S3Object> changeListener; private final Consumer<ResponseInputStream<GetObjectResponse>> changeListener;
private final AtomicReference<String> lastETag = new AtomicReference<>(); private final AtomicReference<String> lastETag = new AtomicReference<>();
private final AmazonS3 s3Client; private final S3Client s3Client;
private static final Logger log = LoggerFactory.getLogger(S3ObjectMonitor.class); private static final Logger log = LoggerFactory.getLogger(S3ObjectMonitor.class);
@ -51,11 +54,11 @@ public class S3ObjectMonitor implements Managed {
final long maxObjectSize, final long maxObjectSize,
final ScheduledExecutorService refreshExecutorService, final ScheduledExecutorService refreshExecutorService,
final Duration refreshInterval, final Duration refreshInterval,
final Consumer<S3Object> changeListener) { final Consumer<ResponseInputStream<GetObjectResponse>> changeListener) {
this(AmazonS3ClientBuilder.standard() this(S3Client.builder()
.withCredentials(InstanceProfileCredentialsProvider.getInstance()) .region(Region.of(s3Region))
.withRegion(s3Region) .credentialsProvider(InstanceProfileCredentialsProvider.create())
.build(), .build(),
s3Bucket, s3Bucket,
objectKey, objectKey,
@ -67,13 +70,13 @@ public class S3ObjectMonitor implements Managed {
@VisibleForTesting @VisibleForTesting
S3ObjectMonitor( S3ObjectMonitor(
final AmazonS3 s3Client, final S3Client s3Client,
final String s3Bucket, final String s3Bucket,
final String objectKey, final String objectKey,
final long maxObjectSize, final long maxObjectSize,
final ScheduledExecutorService refreshExecutorService, final ScheduledExecutorService refreshExecutorService,
final Duration refreshInterval, final Duration refreshInterval,
final Consumer<S3Object> changeListener) { final Consumer<ResponseInputStream<GetObjectResponse>> changeListener) {
this.s3Client = s3Client; this.s3Client = s3Client;
this.s3Bucket = s3Bucket; this.s3Bucket = s3Bucket;
@ -92,8 +95,13 @@ public class S3ObjectMonitor implements Managed {
throw new RuntimeException("S3 object manager already started"); throw new RuntimeException("S3 object manager already started");
} }
// Run the first request immediately/blocking, then start subsequent calls.
log.info("Initial request for s3://{}/{}", s3Bucket, objectKey);
refresh();
refreshFuture = refreshExecutorService refreshFuture = refreshExecutorService
.scheduleAtFixedRate(this::refresh, refreshInterval.toMillis(), refreshInterval.toMillis(), TimeUnit.MILLISECONDS); .scheduleAtFixedRate(this::refresh, refreshInterval.toMillis(), refreshInterval.toMillis(),
TimeUnit.MILLISECONDS);
} }
@Override @Override
@ -106,21 +114,24 @@ public class S3ObjectMonitor implements Managed {
/** /**
* Immediately returns the monitored S3 object regardless of whether it has changed since it was last retrieved. * Immediately returns the monitored S3 object regardless of whether it has changed since it was last retrieved.
* *
* @return the current version of the monitored S3 object * @return the current version of the monitored S3 object. Caller should close() this upon completion.
*
* @throws IOException if the retrieved S3 object is larger than the configured maximum size * @throws IOException if the retrieved S3 object is larger than the configured maximum size
*/ */
public S3Object getObject() throws IOException { @VisibleForTesting
final S3Object s3Object = s3Client.getObject(s3Bucket, objectKey); ResponseInputStream<GetObjectResponse> getObject() throws IOException {
ResponseInputStream<GetObjectResponse> response = s3Client.getObject(GetObjectRequest.builder()
.key(objectKey)
.bucket(s3Bucket)
.build());
lastETag.set(s3Object.getObjectMetadata().getETag()); lastETag.set(response.response().eTag());
if (s3Object.getObjectMetadata().getContentLength() <= maxObjectSize) { if (response.response().contentLength() <= maxObjectSize) {
return s3Object; return response;
} else { } else {
log.warn("Object at s3://{}/{} has a size of {} bytes, which exceeds the maximum allowed size of {} bytes", log.warn("Object at s3://{}/{} has a size of {} bytes, which exceeds the maximum allowed size of {} bytes",
s3Bucket, objectKey, s3Object.getObjectMetadata().getContentLength(), maxObjectSize); s3Bucket, objectKey, response.response().contentLength(), maxObjectSize);
response.abort();
throw new IOException("S3 object too large"); throw new IOException("S3 object too large");
} }
} }
@ -132,18 +143,25 @@ public class S3ObjectMonitor implements Managed {
@VisibleForTesting @VisibleForTesting
void refresh() { void refresh() {
try { try {
final ObjectMetadata objectMetadata = s3Client.getObjectMetadata(s3Bucket, objectKey); final HeadObjectResponse objectMetadata = s3Client.headObject(HeadObjectRequest.builder()
.bucket(s3Bucket)
.key(objectKey)
.build());
final String initialETag = lastETag.get(); final String initialETag = lastETag.get();
final String refreshedETag = objectMetadata.getETag(); final String refreshedETag = objectMetadata.eTag();
if (!StringUtils.equals(initialETag, refreshedETag) && lastETag.compareAndSet(initialETag, refreshedETag)) { if (!StringUtils.equals(initialETag, refreshedETag) && lastETag.compareAndSet(initialETag, refreshedETag)) {
final S3Object s3Object = getObject(); final ResponseInputStream<GetObjectResponse> response = getObject();
log.info("Object at s3://{}/{} has changed; new eTag is {} and object size is {} bytes", log.info("Object at s3://{}/{} has changed; new eTag is {} and object size is {} bytes",
s3Bucket, objectKey, s3Object.getObjectMetadata().getETag(), s3Object.getObjectMetadata().getContentLength()); s3Bucket, objectKey, response.response().eTag(), response.response().contentLength());
changeListener.accept(s3Object); try {
changeListener.accept(response);
} finally {
response.close();
}
} }
} catch (final Exception e) { } catch (final Exception e) {
log.warn("Failed to refresh monitored object", e); log.warn("Failed to refresh monitored object", e);

View File

@ -7,14 +7,12 @@ package org.whispersystems.textsecuregcm.util;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.services.s3.model.S3Object;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.util.Collections; import java.util.Collections;
@ -25,6 +23,8 @@ import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
/** /**
* A utility for checking whether IP addresses belong to Tor exit nodes using the "bulk exit list." * A utility for checking whether IP addresses belong to Tor exit nodes using the "bulk exit list."
@ -58,12 +58,6 @@ public class TorExitNodeManager implements Managed {
@Override @Override
public synchronized void start() { public synchronized void start() {
try {
handleExitListChanged(exitListMonitor.getObject());
} catch (final Exception e) {
log.warn("Failed to load initial Tor exit node list", e);
}
exitListMonitor.start(); exitListMonitor.start();
} }
@ -76,12 +70,12 @@ public class TorExitNodeManager implements Managed {
return exitNodeAddresses.get().contains(address); return exitNodeAddresses.get().contains(address);
} }
private void handleExitListChanged(final S3Object exitList) { private void handleExitListChanged(final ResponseInputStream<GetObjectResponse> exitList) {
REFRESH_TIMER.record(() -> handleExitListChanged(exitList.getObjectContent())); REFRESH_TIMER.record(() -> handleExitListChanged(exitList));
} }
@VisibleForTesting @VisibleForTesting
void handleExitListChanged(final InputStream inputStream) { void handleExitListChangedStream(final InputStream inputStream) {
try (final BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { try (final BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
exitNodeAddresses.set(reader.lines().collect(Collectors.toSet())); exitNodeAddresses.set(reader.lines().collect(Collectors.toSet()));
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -8,17 +8,11 @@ package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.Inet4Address;
import java.nio.charset.StandardCharsets;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -34,7 +28,7 @@ class AsnManagerTest {
assertEquals(Optional.empty(), asnManager.getAsn("10.0.0.1")); assertEquals(Optional.empty(), asnManager.getAsn("10.0.0.1"));
try (final InputStream tableInputStream = getClass().getResourceAsStream("ip2asn-test.tsv")) { try (final InputStream tableInputStream = getClass().getResourceAsStream("ip2asn-test.tsv")) {
asnManager.handleAsnTableChanged(tableInputStream); asnManager.handleAsnTableChangedStream(tableInputStream);
} }
assertEquals(Optional.of(7922L), asnManager.getAsn("50.79.54.1")); assertEquals(Optional.of(7922L), asnManager.getAsn("50.79.54.1"));

View File

@ -1,11 +1,17 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.S3Object;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
@ -21,15 +27,13 @@ class S3ObjectMonitorTest {
@Test @Test
void refresh() { void refresh() {
final AmazonS3 s3Client = mock(AmazonS3.class); final S3Client s3Client = mock(S3Client.class);
final ObjectMetadata metadata = mock(ObjectMetadata.class);
final S3Object s3Object = mock(S3Object.class);
final String bucket = "s3bucket"; final String bucket = "s3bucket";
final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
//noinspection unchecked //noinspection unchecked
final Consumer<S3Object> listener = mock(Consumer.class); final Consumer<ResponseInputStream<GetObjectResponse>> listener = mock(Consumer.class);
final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
s3Client, s3Client,
@ -40,28 +44,27 @@ class S3ObjectMonitorTest {
Duration.ofMinutes(1), Duration.ofMinutes(1),
listener); listener);
when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); String uuid = UUID.randomUUID().toString();
when(s3Object.getObjectMetadata()).thenReturn(metadata); when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(
when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); HeadObjectResponse.builder().eTag(uuid).build());
when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); ResponseInputStream<GetObjectResponse> ris = responseInputStreamFromString("abc", uuid);
when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris);
objectMonitor.refresh(); objectMonitor.refresh();
objectMonitor.refresh(); objectMonitor.refresh();
verify(listener).accept(s3Object); verify(listener).accept(ris);
} }
@Test @Test
void refreshAfterGet() throws IOException { void refreshAfterGet() throws IOException {
final AmazonS3 s3Client = mock(AmazonS3.class); final S3Client s3Client = mock(S3Client.class);
final ObjectMetadata metadata = mock(ObjectMetadata.class);
final S3Object s3Object = mock(S3Object.class);
final String bucket = "s3bucket"; final String bucket = "s3bucket";
final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
//noinspection unchecked //noinspection unchecked
final Consumer<S3Object> listener = mock(Consumer.class); final Consumer<ResponseInputStream<GetObjectResponse>> listener = mock(Consumer.class);
final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
s3Client, s3Client,
@ -72,29 +75,34 @@ class S3ObjectMonitorTest {
Duration.ofMinutes(1), Duration.ofMinutes(1),
listener); listener);
when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); String uuid = UUID.randomUUID().toString();
when(s3Object.getObjectMetadata()).thenReturn(metadata); when(s3Client.headObject(HeadObjectRequest.builder().key(objectKey).bucket(bucket).build()))
when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); .thenReturn(HeadObjectResponse.builder().eTag(uuid).build());
when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); ResponseInputStream<GetObjectResponse> responseInputStream = responseInputStreamFromString("abc", uuid);
when(s3Client.getObject(GetObjectRequest.builder().key(objectKey).bucket(bucket).build())).thenReturn(responseInputStream);
objectMonitor.getObject(); objectMonitor.getObject();
objectMonitor.refresh(); objectMonitor.refresh();
verify(listener, never()).accept(s3Object); verify(listener, never()).accept(responseInputStream);
}
private ResponseInputStream<GetObjectResponse> responseInputStreamFromString(String s, String etag) {
byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
AbortableInputStream ais = AbortableInputStream.create(new ByteArrayInputStream(bytes));
return new ResponseInputStream<>(GetObjectResponse.builder().contentLength((long) bytes.length).eTag(etag).build(), ais);
} }
@Test @Test
void refreshOversizedObject() { void refreshOversizedObject() {
final AmazonS3 s3Client = mock(AmazonS3.class); final S3Client s3Client = mock(S3Client.class);
final ObjectMetadata metadata = mock(ObjectMetadata.class);
final S3Object s3Object = mock(S3Object.class);
final String bucket = "s3bucket"; final String bucket = "s3bucket";
final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip"; final String objectKey = "greatest-smooth-jazz-hits-of-all-time.zip";
final long maxObjectSize = 16 * 1024 * 1024; final long maxObjectSize = 16 * 1024 * 1024;
//noinspection unchecked //noinspection unchecked
final Consumer<S3Object> listener = mock(Consumer.class); final Consumer<ResponseInputStream<GetObjectResponse>> listener = mock(Consumer.class);
final S3ObjectMonitor objectMonitor = new S3ObjectMonitor( final S3ObjectMonitor objectMonitor = new S3ObjectMonitor(
s3Client, s3Client,
@ -105,11 +113,11 @@ class S3ObjectMonitorTest {
Duration.ofMinutes(1), Duration.ofMinutes(1),
listener); listener);
when(metadata.getETag()).thenReturn(UUID.randomUUID().toString()); String uuid = UUID.randomUUID().toString();
when(metadata.getContentLength()).thenReturn(maxObjectSize + 1); when(s3Client.headObject(HeadObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(
when(s3Object.getObjectMetadata()).thenReturn(metadata); HeadObjectResponse.builder().eTag(uuid).contentLength(maxObjectSize+1).build());
when(s3Client.getObjectMetadata(bucket, objectKey)).thenReturn(metadata); ResponseInputStream<GetObjectResponse> ris = responseInputStreamFromString("a".repeat((int) maxObjectSize+1), uuid);
when(s3Client.getObject(bucket, objectKey)).thenReturn(s3Object); when(s3Client.getObject(GetObjectRequest.builder().bucket(bucket).key(objectKey).build())).thenReturn(ris);
objectMonitor.refresh(); objectMonitor.refresh();

View File

@ -29,7 +29,7 @@ public class TorExitNodeManagerTest extends AbstractRedisClusterTest {
assertFalse(torExitNodeManager.isTorExitNode("10.0.0.1")); assertFalse(torExitNodeManager.isTorExitNode("10.0.0.1"));
assertFalse(torExitNodeManager.isTorExitNode("10.0.0.2")); assertFalse(torExitNodeManager.isTorExitNode("10.0.0.2"));
torExitNodeManager.handleExitListChanged( torExitNodeManager.handleExitListChangedStream(
new ByteArrayInputStream("10.0.0.1\n10.0.0.2".getBytes(StandardCharsets.UTF_8))); new ByteArrayInputStream("10.0.0.1\n10.0.0.2".getBytes(StandardCharsets.UTF_8)));
assertTrue(torExitNodeManager.isTorExitNode("10.0.0.1")); assertTrue(torExitNodeManager.isTorExitNode("10.0.0.1"));