diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 61a16d1eb..3859c09d9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -15,7 +15,9 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; +import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -40,7 +42,6 @@ import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; -import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; @@ -59,6 +60,8 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v2/keys") @@ -237,43 +240,28 @@ public class KeysController { io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId))))) .increment(); - final List devices = parseDeviceId(deviceId, target); - final List responseItems = new ArrayList<>(devices.size()); - - final List> tasks = devices.stream().map(device -> { - final CompletableFuture> unsignedEcPreKeyFuture = - keys.takeEC(targetIdentifier.uuid(), device.getId()); - - final CompletableFuture> signedEcPreKeyFuture = - keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()); - - final CompletableFuture> pqPreKeyFuture = returnPqKey - ? keys.takePQ(targetIdentifier.uuid(), device.getId()) - : CompletableFuture.completedFuture(Optional.empty()); - - return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture) - .thenAccept(ignored -> { - final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null); - final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null); - final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null); - - if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) { - final int registrationId = switch (targetIdentifier.identityType()) { - case ACI -> device.getRegistrationId(); - case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); - }; - - synchronized (responseItems) { - responseItems.add( - new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey, - pqPreKey)); - } - } - }); - }) - .toList(); - - CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join(); + final List responseItems = Flux.fromIterable(parseDeviceId(deviceId, target)) + .flatMap(device -> Mono.zip( + Mono.just(device), + Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())), + Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), + Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) + : CompletableFuture.>completedFuture(Optional.empty())) + )).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) + .map(deviceAndKeys -> { + final Device device = deviceAndKeys.getT1(); + final int registrationId = switch (targetIdentifier.identityType()) { + case ACI -> device.getRegistrationId(); + case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); + }; + return new PreKeyResponseItem(device.getId(), registrationId, + deviceAndKeys.getT2().orElse(null), + deviceAndKeys.getT3().orElse(null), + deviceAndKeys.getT4().orElse(null)); + }).collectList() + .timeout(Duration.ofSeconds(30)) + .blockOptional() + .orElse(Collections.emptyList()); final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());