diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java index 22680a266..83f540773 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java @@ -10,6 +10,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; +import java.time.Instant; import java.util.Optional; import java.util.function.BiConsumer; import org.apache.commons.lang3.StringUtils; @@ -131,7 +132,11 @@ public class PushNotificationManager { if (result.unregistered() && pushNotification.destination() != null && pushNotification.destinationDevice() != null) { - handleDeviceUnregistered(pushNotification.destination(), pushNotification.destinationDevice()); + + handleDeviceUnregistered(pushNotification.destination(), + pushNotification.destinationDevice(), + pushNotification.tokenType(), + result.unregisteredTimestamp()); } if (result.accepted() && @@ -164,28 +169,53 @@ public class PushNotificationManager { }; } - private void handleDeviceUnregistered(final Account account, final Device device) { - if (StringUtils.isNotBlank(device.getGcmId())) { - final String originalFcmId = device.getGcmId(); + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private void handleDeviceUnregistered(final Account account, + final Device device, + final PushNotification.TokenType tokenType, + final Optional maybeTokenInvalidationTimestamp) { - // Reread the account to avoid marking the caller's account as stale. The consumers of this class tend to - // promise not to modify accounts. There's no need to force the caller to be considered mutable just for - // updating an uninstalled feedback timestamp though. - final Optional rereadAccount = accountsManager.getByAccountIdentifier(account.getUuid()); - if (rereadAccount.isEmpty()) { - // Don't bother removing the token; the account is gone - return; + final boolean tokenExpired = maybeTokenInvalidationTimestamp.map(tokenInvalidationTimestamp -> + tokenInvalidationTimestamp.isAfter(Instant.ofEpochMilli(device.getPushTimestamp()))).orElse(true); + + if (tokenExpired) { + if (tokenType == PushNotification.TokenType.APN || tokenType == PushNotification.TokenType.APN_VOIP) { + apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); } - rereadAccount.get().getDevice(device.getId()).ifPresent(rereadDevice -> - accountsManager.updateDevice(rereadAccount.get(), device.getId(), d -> { - // Don't clear the token if it's already changed - if (originalFcmId.equals(d.getGcmId())) { - d.setGcmId(null); - } - })); - } else { - apnPushNotificationScheduler.cancelScheduledNotifications(account, device).whenComplete(logErrors()); + clearPushToken(account, device, tokenType); } } + + private void clearPushToken(final Account account, final Device device, final PushNotification.TokenType tokenType) { + final String originalToken = getPushToken(device, tokenType); + + if (originalToken == null) { + return; + } + + // Reread the account to avoid marking the caller's account as stale. The consumers of this class tend to + // promise not to modify accounts. There's no need to force the caller to be considered mutable just for + // updating an uninstalled feedback timestamp though. + accountsManager.getByAccountIdentifier(account.getUuid()).ifPresent(rereadAccount -> + rereadAccount.getDevice(device.getId()).ifPresent(rereadDevice -> + accountsManager.updateDevice(rereadAccount, device.getId(), d -> { + // Don't clear the token if it's already changed + if (originalToken.equals(getPushToken(d, tokenType))) { + switch (tokenType) { + case FCM -> d.setGcmId(null); + case APN -> d.setApnId(null); + case APN_VOIP -> d.setVoipApnId(null); + } + } + }))); + } + + private static String getPushToken(final Device device, final PushNotification.TokenType tokenType) { + return switch (tokenType) { + case FCM -> device.getGcmId(); + case APN -> device.getApnId(); + case APN_VOIP -> device.getVoipApnId(); + }; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java index e64adf5b7..9c49945fa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushNotificationManagerTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.google.common.net.HttpHeaders; +import java.time.Instant; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -244,9 +245,13 @@ class PushNotificationManagerTest { void testSendNotificationUnregisteredApn() { final Account account = mock(Account.class); final Device device = mock(Device.class); - + final UUID aci = UUID.randomUUID(); when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(device.getApnId()).thenReturn("apns-token"); + when(device.getVoipApnId()).thenReturn("apns-voip-token"); when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device)); + when(account.getUuid()).thenReturn(aci); + when(accountsManager.getByAccountIdentifier(aci)).thenReturn(Optional.of(account)); final PushNotification pushNotification = new PushNotification( "token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true); @@ -260,11 +265,45 @@ class PushNotificationManagerTest { pushNotificationManager.sendNotification(pushNotification); verifyNoInteractions(fcmSender); - verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); - verify(device, never()).setGcmId(any()); + verify(accountsManager).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); + verify(device).setVoipApnId(null); + verify(device, never()).setApnId(any()); verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device); } + @Test + void testSendNotificationUnregisteredApnTokenUpdated() { + final Instant tokenTimestamp = Instant.now(); + + final Account account = mock(Account.class); + final Device device = mock(Device.class); + final UUID aci = UUID.randomUUID(); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(device.getApnId()).thenReturn("apns-token"); + when(device.getVoipApnId()).thenReturn("apns-voip-token"); + when(device.getPushTimestamp()).thenReturn(tokenTimestamp.toEpochMilli()); + when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device)); + when(account.getUuid()).thenReturn(aci); + when(accountsManager.getByAccountIdentifier(aci)).thenReturn(Optional.of(account)); + + final PushNotification pushNotification = new PushNotification( + "token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true); + + when(apnSender.sendNotification(pushNotification)) + .thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, Optional.empty(), true, Optional.of(tokenTimestamp.minusSeconds(60))))); + + when(apnPushNotificationScheduler.cancelScheduledNotifications(account, device)) + .thenReturn(CompletableFuture.completedFuture(null)); + + pushNotificationManager.sendNotification(pushNotification); + + verifyNoInteractions(fcmSender); + verify(accountsManager, never()).updateDevice(eq(account), eq(Device.PRIMARY_ID), any()); + verify(device, never()).setVoipApnId(any()); + verify(device, never()).setApnId(any()); + verify(apnPushNotificationScheduler, never()).cancelScheduledNotifications(account, device); + } + @Test void testHandleMessagesRetrieved() { final UUID accountIdentifier = UUID.randomUUID();