Wait for outstanding requests to be resolved before shutting down the directory queue.

This commit is contained in:
Jon Chambers 2021-07-26 12:56:33 -04:00 committed by Jon Chambers
parent 34dbff6786
commit 3608c5bfb0
3 changed files with 67 additions and 1 deletions

View File

@ -515,6 +515,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(currencyManager);
environment.lifecycle().manage(torExitNodeManager);
environment.lifecycle().manage(asnManager);
environment.lifecycle().manage(directoryQueue);
StaticCredentialsProvider cdnCredentialsProvider = StaticCredentialsProvider
.create(AwsBasicCredentials.create(

View File

@ -11,9 +11,13 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
@ -28,7 +32,7 @@ import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
public class DirectoryQueue {
public class DirectoryQueue implements Managed {
private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class);
@ -40,6 +44,8 @@ public class DirectoryQueue {
private final List<String> queueUrls;
private final SqsAsyncClient sqs;
private final Set<SendMessageRequest> outstandingRequests = Collections.newSetFromMap(new IdentityHashMap<>());
private enum UpdateAction {
ADD("add"),
DELETE("delete");
@ -73,6 +79,21 @@ public class DirectoryQueue {
this.sqs = sqs;
}
@Override
public void start() throws Exception {
}
@Override
public void stop() throws Exception {
synchronized (outstandingRequests) {
while (!outstandingRequests.isEmpty()) {
outstandingRequests.wait();
}
}
sqs.close();
}
public boolean isDiscoverable(final Account account) {
return account.isEnabled() && account.isDiscoverableByPhoneNumber();
}
@ -101,6 +122,10 @@ public class DirectoryQueue {
))
.build();
synchronized (outstandingRequests) {
outstandingRequests.add(request);
}
sqs.sendMessage(request).whenComplete((response, cause) -> {
try {
if (cause instanceof SdkServiceException) {
@ -113,6 +138,11 @@ public class DirectoryQueue {
logger.warn("sqs unexpected error", cause);
}
} finally {
synchronized (outstandingRequests) {
outstandingRequests.remove(request);
outstandingRequests.notifyAll();
}
timerContext.close();
}
});

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.sqs;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@ -15,6 +17,8 @@ import static org.mockito.Mockito.when;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -89,4 +93,35 @@ public class DirectoryQueueTest {
sendMessageRequest.messageAttributes().get("action"));
}
}
@Test
void testStop() {
final CompletableFuture<SendMessageResponse> sendMessageFuture = new CompletableFuture<>();
when(sqsAsyncClient.sendMessage(any(SendMessageRequest.class))).thenReturn(sendMessageFuture);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqsAsyncClient);
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(true);
when(account.isDiscoverableByPhoneNumber()).thenReturn(true);
directoryQueue.refreshAccount(account);
final CompletableFuture<Boolean> stopFuture = CompletableFuture.supplyAsync(() -> {
try {
directoryQueue.stop();
return true;
} catch (final Exception e) {
return false;
}
});
assertThrows(TimeoutException.class, () -> stopFuture.get(1, TimeUnit.SECONDS),
"Directory queue should not finish shutting down until all outstanding requests are resolved");
sendMessageFuture.complete(SendMessageResponse.builder().build());
assertTrue(stopFuture.join());
}
}