Refactor key-fetching to be reactive

This commit is contained in:
Jon Chambers 2023-12-13 12:41:51 -05:00 committed by Jon Chambers
parent 4ce060a963
commit 609c901867
1 changed files with 26 additions and 38 deletions

View File

@ -15,7 +15,9 @@ import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -40,7 +42,6 @@ import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@ -59,6 +60,8 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys") @Path("/v2/keys")
@ -237,43 +240,28 @@ public class KeysController {
io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId))))) io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId)))))
.increment(); .increment();
final List<Device> devices = parseDeviceId(deviceId, target); final List<PreKeyResponseItem> responseItems = Flux.fromIterable(parseDeviceId(deviceId, target))
final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size()); .flatMap(device -> Mono.zip(
Mono.just(device),
final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> { Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())),
final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture = Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())),
keys.takeEC(targetIdentifier.uuid(), device.getId()); Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.<Optional<KEMSignedPreKey>>completedFuture(Optional.empty()))
final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture = )).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent())
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()); .map(deviceAndKeys -> {
final Device device = deviceAndKeys.getT1();
final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());
return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
.thenAccept(ignored -> {
final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null);
final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null);
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) { final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId(); case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
}; };
return new PreKeyResponseItem(device.getId(), registrationId,
synchronized (responseItems) { deviceAndKeys.getT2().orElse(null),
responseItems.add( deviceAndKeys.getT3().orElse(null),
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey, deviceAndKeys.getT4().orElse(null));
pqPreKey)); }).collectList()
} .timeout(Duration.ofSeconds(30))
} .blockOptional()
}); .orElse(Collections.emptyList());
})
.toList();
CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();
final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType()); final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());