Support for push preauth

This commit is contained in:
Moxie Marlinspike 2019-06-06 17:31:07 -07:00
parent 18037bb484
commit 4fdbe9b9ff
22 changed files with 391 additions and 103 deletions

View File

@ -247,7 +247,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter))); DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)));
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))); environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient)); environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices())); environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices()));
environment.jersey().register(new DirectoryController(rateLimiters, directory, directoryCredentialsGenerator)); environment.jersey().register(new DirectoryController(rateLimiters, directory, directoryCredentialsGenerator));
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender)); environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));

View File

@ -2,6 +2,8 @@ package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.util.Util;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -13,11 +15,15 @@ public class StoredVerificationCode {
@JsonProperty @JsonProperty
private long timestamp; private long timestamp;
@JsonProperty
private String pushCode;
public StoredVerificationCode() {} public StoredVerificationCode() {}
public StoredVerificationCode(String code, long timestamp) { public StoredVerificationCode(String code, long timestamp, String pushCode) {
this.code = code; this.code = code;
this.timestamp = timestamp; this.timestamp = timestamp;
this.pushCode = pushCode;
} }
public String getCode() { public String getCode() {
@ -28,8 +34,16 @@ public class StoredVerificationCode {
return timestamp; return timestamp;
} }
public String getPushCode() {
return pushCode;
}
public boolean isValid(String theirCodeString) { public boolean isValid(String theirCodeString) {
if (timestamp + TimeUnit.MINUTES.toMillis(30) < System.currentTimeMillis()) { if (timestamp + TimeUnit.MINUTES.toMillis(10) < System.currentTimeMillis()) {
return false;
}
if (Util.isEmpty(code) || Util.isEmpty(theirCodeString)) {
return false; return false;
} }
@ -38,4 +52,5 @@ public class StoredVerificationCode {
return MessageDigest.isEqual(ourCode, theirCode); return MessageDigest.isEqual(ourCode, theirCode);
} }
} }

View File

@ -37,6 +37,10 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient; import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
@ -48,6 +52,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -90,16 +95,18 @@ public class AccountController {
private final Meter captchaFailureMeter = metricRegistry.meter(name(AccountController.class, "captcha_failure" )); private final Meter captchaFailureMeter = metricRegistry.meter(name(AccountController.class, "captcha_failure" ));
private final PendingAccountsManager pendingAccounts; private final PendingAccountsManager pendingAccounts;
private final AccountsManager accounts; private final AccountsManager accounts;
private final AbusiveHostRules abusiveHostRules; private final AbusiveHostRules abusiveHostRules;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final SmsSender smsSender; private final SmsSender smsSender;
private final DirectoryQueue directoryQueue; private final DirectoryQueue directoryQueue;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final TurnTokenGenerator turnTokenGenerator; private final TurnTokenGenerator turnTokenGenerator;
private final Map<String, Integer> testDevices; private final Map<String, Integer> testDevices;
private final RecaptchaClient recaptchaClient; private final RecaptchaClient recaptchaClient;
private final GCMSender gcmSender;
private final APNSender apnSender;
public AccountController(PendingAccountsManager pendingAccounts, public AccountController(PendingAccountsManager pendingAccounts,
AccountsManager accounts, AccountsManager accounts,
@ -110,7 +117,9 @@ public class AccountController {
MessagesManager messagesManager, MessagesManager messagesManager,
TurnTokenGenerator turnTokenGenerator, TurnTokenGenerator turnTokenGenerator,
Map<String, Integer> testDevices, Map<String, Integer> testDevices,
RecaptchaClient recaptchaClient) RecaptchaClient recaptchaClient,
GCMSender gcmSender,
APNSender apnSender)
{ {
this.pendingAccounts = pendingAccounts; this.pendingAccounts = pendingAccounts;
this.accounts = accounts; this.accounts = accounts;
@ -122,6 +131,41 @@ public class AccountController {
this.testDevices = testDevices; this.testDevices = testDevices;
this.turnTokenGenerator = turnTokenGenerator; this.turnTokenGenerator = turnTokenGenerator;
this.recaptchaClient = recaptchaClient; this.recaptchaClient = recaptchaClient;
this.gcmSender = gcmSender;
this.apnSender = apnSender;
}
@Timed
@GET
@Path("/{type}/preauth/{token}/{number}")
public Response getPreAuth(@PathParam("type") String pushType,
@PathParam("token") String pushToken,
@PathParam("number") String number)
{
if (!"apn".equals(pushType) && !"fcm".equals(pushType)) {
return Response.status(400).build();
}
if (!Util.isValidNumber(number)) {
return Response.status(400).build();
}
String pushChallenge = generatePushChallenge();
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null,
System.currentTimeMillis(),
pushChallenge);
pendingAccounts.store(number, storedVerificationCode);
if ("fcm".equals(pushType)) {
gcmSender.sendMessage(new GcmMessage(pushToken, number, 0, GcmMessage.Type.CHALLENGE, Optional.of(storedVerificationCode.getPushCode())));
} else if ("apn".equals(pushType)) {
apnSender.sendMessage(new ApnMessage(pushToken, number, 0, true, Optional.of(storedVerificationCode.getPushCode())));
} else {
throw new AssertionError();
}
return Response.ok().build();
} }
@Timed @Timed
@ -132,7 +176,8 @@ public class AccountController {
@HeaderParam("X-Forwarded-For") String forwardedFor, @HeaderParam("X-Forwarded-For") String forwardedFor,
@HeaderParam("Accept-Language") Optional<String> locale, @HeaderParam("Accept-Language") Optional<String> locale,
@QueryParam("client") Optional<String> client, @QueryParam("client") Optional<String> client,
@QueryParam("captcha") Optional<String> captcha) @QueryParam("captcha") Optional<String> captcha,
@QueryParam("challenge") Optional<String> pushChallenge)
throws RateLimitExceededException throws RateLimitExceededException
{ {
if (!Util.isValidNumber(number)) { if (!Util.isValidNumber(number)) {
@ -145,7 +190,8 @@ public class AccountController {
.reduce((a, b) -> b) .reduce((a, b) -> b)
.orElseThrow(); .orElseThrow();
CaptchaRequirement requirement = requiresCaptcha(number, transport, forwardedFor, requester, captcha); Optional<StoredVerificationCode> storedChallenge = pendingAccounts.getCodeForNumber(number);
CaptchaRequirement requirement = requiresCaptcha(number, transport, forwardedFor, requester, captcha, storedChallenge, pushChallenge);
if (requirement.isCaptchaRequired()) { if (requirement.isCaptchaRequired()) {
if (requirement.isAutoBlock() && shouldAutoBlock(requester)) { if (requirement.isAutoBlock() && shouldAutoBlock(requester)) {
@ -170,7 +216,8 @@ public class AccountController {
VerificationCode verificationCode = generateVerificationCode(number); VerificationCode verificationCode = generateVerificationCode(number);
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(), StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(),
System.currentTimeMillis()); System.currentTimeMillis(),
storedChallenge.map(StoredVerificationCode::getPushCode).orElse(null));
pendingAccounts.store(number, storedVerificationCode); pendingAccounts.store(number, storedVerificationCode);
@ -397,7 +444,10 @@ public class AccountController {
} }
private CaptchaRequirement requiresCaptcha(String number, String transport, String forwardedFor, private CaptchaRequirement requiresCaptcha(String number, String transport, String forwardedFor,
String requester, Optional<String> captchaToken) String requester,
Optional<String> captchaToken,
Optional<StoredVerificationCode> storedVerificationCode,
Optional<String> pushChallenge)
{ {
if (captchaToken.isPresent()) { if (captchaToken.isPresent()) {
@ -412,6 +462,14 @@ public class AccountController {
} }
} }
if (pushChallenge.isPresent()) {
Optional<String> storedPushChallenge = storedVerificationCode.map(StoredVerificationCode::getPushCode);
if (!pushChallenge.get().equals(storedPushChallenge.orElse(null))) {
return new CaptchaRequirement(true, false);
}
}
List<AbusiveHostRule> abuseRules = abusiveHostRules.getAbusiveHostRulesFor(requester); List<AbusiveHostRule> abuseRules = abusiveHostRules.getAbusiveHostRulesFor(requester);
for (AbusiveHostRule abuseRule : abuseRules) { for (AbusiveHostRule abuseRule : abuseRules) {
@ -493,7 +551,8 @@ public class AccountController {
pendingAccounts.remove(number); pendingAccounts.remove(number);
} }
@VisibleForTesting protected VerificationCode generateVerificationCode(String number) { @VisibleForTesting protected
VerificationCode generateVerificationCode(String number) {
if (testDevices.containsKey(number)) { if (testDevices.containsKey(number)) {
return new VerificationCode(testDevices.get(number)); return new VerificationCode(testDevices.get(number));
} }
@ -503,6 +562,14 @@ public class AccountController {
return new VerificationCode(randomInt); return new VerificationCode(randomInt);
} }
private String generatePushChallenge() {
SecureRandom random = new SecureRandom();
byte[] challenge = new byte[16];
random.nextBytes(challenge);
return Hex.toStringCondensed(challenge);
}
private static class CaptchaRequirement { private static class CaptchaRequirement {
private final boolean captchaRequired; private final boolean captchaRequired;
private final boolean autoBlock; private final boolean autoBlock;

View File

@ -144,7 +144,8 @@ public class DeviceController {
VerificationCode verificationCode = generateVerificationCode(); VerificationCode verificationCode = generateVerificationCode();
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(), StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(),
System.currentTimeMillis()); System.currentTimeMillis(),
null);
pendingDevices.store(account.getNumber(), storedVerificationCode); pendingDevices.store(account.getNumber(), storedVerificationCode);

View File

@ -94,6 +94,8 @@ public class APNSender implements Managed {
Futures.addCallback(future, new FutureCallback<ApnResult>() { Futures.addCallback(future, new FutureCallback<ApnResult>() {
@Override @Override
public void onSuccess(@Nullable ApnResult result) { public void onSuccess(@Nullable ApnResult result) {
if (message.getChallengeData().isPresent()) return;
if (result == null) { if (result == null) {
logger.warn("*** RECEIVED NULL APN RESULT ***"); logger.warn("*** RECEIVED NULL APN RESULT ***");
} else if (result.getStatus() == ApnResult.Status.NO_SUCH_USER) { } else if (result.getStatus() == ApnResult.Status.NO_SUCH_USER) {

View File

@ -157,7 +157,7 @@ public class ApnFallbackManager implements Managed, Runnable {
continue; continue;
} }
apnSender.sendMessage(new ApnMessage(apnId, separated.get().first(), separated.get().second(), true)); apnSender.sendMessage(new ApnMessage(apnId, separated.get().first(), separated.get().second(), true, Optional.empty()));
retry.mark(); retry.mark();
} }

View File

@ -1,20 +1,28 @@
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import com.google.common.annotations.VisibleForTesting;
import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class ApnMessage { public class ApnMessage {
public static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}"; public static final String APN_NOTIFICATION_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}";
public static final long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L; public static final String APN_CHALLENGE_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}, \"challenge\" : \"%s\"}";
public static final long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L;
private final String apnId; private final String apnId;
private final String number; private final String number;
private final long deviceId; private final long deviceId;
private final boolean isVoip; private final boolean isVoip;
private final Optional<String> challengeData;
public ApnMessage(String apnId, String number, long deviceId, boolean isVoip) { public ApnMessage(String apnId, String number, long deviceId, boolean isVoip, Optional<String> challengeData) {
this.apnId = apnId; this.apnId = apnId;
this.number = number; this.number = number;
this.deviceId = deviceId; this.deviceId = deviceId;
this.isVoip = isVoip; this.isVoip = isVoip;
this.challengeData = challengeData;
} }
public boolean isVoip() { public boolean isVoip() {
@ -26,7 +34,13 @@ public class ApnMessage {
} }
public String getMessage() { public String getMessage() {
return APN_PAYLOAD; if (!challengeData.isPresent()) return APN_NOTIFICATION_PAYLOAD;
else return String.format(APN_CHALLENGE_PAYLOAD, challengeData.get());
}
@VisibleForTesting
public Optional<String> getChallengeData() {
return challengeData;
} }
public long getExpirationTime() { public long getExpirationTime() {

View File

@ -66,14 +66,22 @@ public class GCMSender implements Managed {
.withDestination(message.getGcmId()) .withDestination(message.getGcmId())
.withPriority("high"); .withPriority("high");
String key = message.isReceipt() ? "receipt" : "notification"; String key;
Message request = builder.withDataPart(key, "").build();
switch (message.getType()) {
case RECEIPT: key = "receipt"; break;
case NOTIFICATION: key = "notification"; break;
case CHALLENGE: key = "challenge"; break;
default: throw new AssertionError();
}
Message request = builder.withDataPart(key, message.getData().orElse("")).build();
CompletableFuture<Result> future = signalSender.send(request); CompletableFuture<Result> future = signalSender.send(request);
markOutboundMeter(key); markOutboundMeter(key);
future.handle((result, throwable) -> { future.handle((result, throwable) -> {
if (result != null) { if (result != null && message.getType() != GcmMessage.Type.CHALLENGE) {
if (result.isUnregistered() || result.isInvalidRegistrationId()) { if (result.isUnregistered() || result.isInvalidRegistrationId()) {
executor.submit(() -> handleBadRegistration(message)); executor.submit(() -> handleBadRegistration(message));
} else if (result.hasCanonicalRegistrationId()) { } else if (result.hasCanonicalRegistrationId()) {

View File

@ -1,17 +1,27 @@
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class GcmMessage { public class GcmMessage {
private final String gcmId; public enum Type {
private final String number; RECEIPT, NOTIFICATION, CHALLENGE
private final int deviceId; }
private final boolean receipt;
public GcmMessage(String gcmId, String number, int deviceId, boolean receipt) { private final String gcmId;
this.gcmId = gcmId; private final String number;
this.number = number; private final int deviceId;
this.deviceId = deviceId; private final Type type;
this.receipt = receipt; private final Optional<String> data;
public GcmMessage(String gcmId, String number, int deviceId, Type type, Optional<String> data) {
this.gcmId = gcmId;
this.number = number;
this.deviceId = deviceId;
this.type = type;
this.data = data;
} }
public String getGcmId() { public String getGcmId() {
@ -22,11 +32,16 @@ public class GcmMessage {
return number; return number;
} }
public boolean isReceipt() { public Type getType() {
return receipt; return type;
} }
public int getDeviceId() { public int getDeviceId() {
return deviceId; return deviceId;
} }
public Optional<String> getData() {
return data;
}
} }

View File

@ -28,6 +28,7 @@ import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.util.Optional;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -105,7 +106,7 @@ public class PushSender implements Managed {
private void sendGcmNotification(Account account, Device device) { private void sendGcmNotification(Account account, Device device) {
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(), GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), false); (int)device.getId(), GcmMessage.Type.NOTIFICATION, Optional.empty());
gcmSender.sendMessage(gcmMessage); gcmSender.sendMessage(gcmMessage);
} }
@ -126,10 +127,10 @@ public class PushSender implements Managed {
} }
if (!Util.isEmpty(device.getVoipApnId())) { if (!Util.isEmpty(device.getVoipApnId())) {
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true); apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true, Optional.empty());
RedisOperation.unchecked(() -> apnFallbackManager.schedule(account, device)); RedisOperation.unchecked(() -> apnFallbackManager.schedule(account, device));
} else { } else {
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false); apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false, Optional.empty());
} }
apnSender.sendMessage(apnMessage); apnSender.sendMessage(apnMessage);

View File

@ -19,17 +19,10 @@ package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.PreparedBatch;
import org.jdbi.v3.core.statement.StatementContext;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper; import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -49,17 +42,17 @@ public class PendingAccounts {
this.database.getDatabase().registerRowMapper(new StoredVerificationCodeRowMapper()); this.database.getDatabase().registerRowMapper(new StoredVerificationCodeRowMapper());
} }
public void insert(String number, String verificationCode, long timestamp) { public void insert(String number, String verificationCode, long timestamp, String pushCode) {
database.use(jdbi -> jdbi.useTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = insertTimer.time()) { try (Timer.Context ignored = insertTimer.time()) {
handle.createUpdate("DELETE FROM pending_accounts WHERE number = :number") handle.createUpdate("INSERT INTO pending_accounts (number, verification_code, timestamp, push_code) " +
.bind("number", number) "VALUES (:number, :verification_code, :timestamp, :push_code) " +
.execute(); "ON CONFLICT(number) DO UPDATE " +
"SET verification_code = EXCLUDED.verification_code, timestamp = EXCLUDED.timestamp, push_code = EXCLUDED.push_code")
handle.createUpdate("INSERT INTO pending_accounts (number, verification_code, timestamp) VALUES (:number, :verification_code, :timestamp)")
.bind("verification_code", verificationCode) .bind("verification_code", verificationCode)
.bind("timestamp", timestamp) .bind("timestamp", timestamp)
.bind("number", number) .bind("number", number)
.bind("push_code", pushCode)
.execute(); .execute();
} }
})); }));
@ -68,7 +61,7 @@ public class PendingAccounts {
public Optional<StoredVerificationCode> getCodeForNumber(String number) { public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.with(jdbi ->jdbi.withHandle(handle -> { return database.with(jdbi ->jdbi.withHandle(handle -> {
try (Timer.Context ignored = getCodeForNumberTimer.time()) { try (Timer.Context ignored = getCodeForNumberTimer.time()) {
return handle.createQuery("SELECT verification_code, timestamp FROM pending_accounts WHERE number = :number") return handle.createQuery("SELECT verification_code, timestamp, push_code FROM pending_accounts WHERE number = :number")
.bind("number", number) .bind("number", number)
.mapTo(StoredVerificationCode.class) .mapTo(StoredVerificationCode.class)
.findFirst(); .findFirst();

View File

@ -48,7 +48,7 @@ public class PendingAccountsManager {
public void store(String number, StoredVerificationCode code) { public void store(String number, StoredVerificationCode code) {
memcacheSet(number, code); memcacheSet(number, code);
pendingAccounts.insert(number, code.getCode(), code.getTimestamp()); pendingAccounts.insert(number, code.getCode(), code.getTimestamp(), code.getPushCode());
} }
public void remove(String number) { public void remove(String number) {

View File

@ -57,7 +57,7 @@ public class PendingDevices {
public Optional<StoredVerificationCode> getCodeForNumber(String number) { public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.with(jdbi -> jdbi.withHandle(handle -> { return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context timer = getCodeForNumberTimer.time()) { try (Timer.Context timer = getCodeForNumberTimer.time()) {
return handle.createQuery("SELECT verification_code, timestamp FROM pending_devices WHERE number = :number") return handle.createQuery("SELECT verification_code, timestamp, NULL as push_code FROM pending_devices WHERE number = :number")
.bind("number", number) .bind("number", number)
.mapTo(StoredVerificationCode.class) .mapTo(StoredVerificationCode.class)
.findFirst(); .findFirst();

View File

@ -12,6 +12,7 @@ public class StoredVerificationCodeRowMapper implements RowMapper<StoredVerifica
@Override @Override
public StoredVerificationCode map(ResultSet resultSet, StatementContext ctx) throws SQLException { public StoredVerificationCode map(ResultSet resultSet, StatementContext ctx) throws SQLException {
return new StoredVerificationCode(resultSet.getString("verification_code"), return new StoredVerificationCode(resultSet.getString("verification_code"),
resultSet.getLong("timestamp")); resultSet.getLong("timestamp"),
resultSet.getString("push_code"));
} }
} }

View File

@ -31,14 +31,18 @@ public class Hex {
}; };
public static String toString(byte[] bytes) { public static String toString(byte[] bytes) {
return toString(bytes, 0, bytes.length); return toString(bytes, 0, bytes.length, false);
} }
public static String toString(byte[] bytes, int offset, int length) { public static String toStringCondensed(byte[] bytes) {
return toString(bytes, 0, bytes.length, true);
}
public static String toString(byte[] bytes, int offset, int length, boolean condensed) {
StringBuffer buf = new StringBuffer(); StringBuffer buf = new StringBuffer();
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
appendHexChar(buf, bytes[offset + i]); appendHexChar(buf, bytes[offset + i]);
buf.append(' '); if (!condensed) buf.append(' ');
} }
return buf.toString(); return buf.toString();
} }

View File

@ -170,6 +170,10 @@ public class Util {
return Arrays.hashCode(objects); return Arrays.hashCode(objects);
} }
public static boolean isEquals(Object first, Object second) {
return (first == null && second == null) || (first == second) || (first != null && first.equals(second));
}
public static long todayInMillis() { public static long todayInMillis() {
return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis())); return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis()));
} }

View File

@ -186,4 +186,15 @@
</column> </column>
</addColumn> </addColumn>
</changeSet> </changeSet>
<changeSet id="6" author="moxie">
<addColumn tableName="pending_accounts">
<column name="push_code" type="text">
<constraints nullable="true"/>
</column>
</addColumn>
<dropNotNullConstraint tableName="pending_accounts" columnName="verification_code"/>
</changeSet>
</databaseChangeLog> </databaseChangeLog>

View File

@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.tests.controllers; package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import net.sourceforge.argparse4j.inf.Argument;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
@ -19,6 +21,10 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.TimeProvider; import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient; import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
@ -53,6 +59,7 @@ public class AccountControllerTest {
private static final String SENDER_PIN = "+14153333333"; private static final String SENDER_PIN = "+14153333333";
private static final String SENDER_OVER_PIN = "+14154444444"; private static final String SENDER_OVER_PIN = "+14154444444";
private static final String SENDER_OVER_PREFIX = "+14156666666"; private static final String SENDER_OVER_PREFIX = "+14156666666";
private static final String SENDER_PREAUTH = "+14157777777";
private static final String ABUSIVE_HOST = "192.168.1.1"; private static final String ABUSIVE_HOST = "192.168.1.1";
private static final String RESTRICTED_HOST = "192.168.1.2"; private static final String RESTRICTED_HOST = "192.168.1.2";
@ -80,6 +87,8 @@ public class AccountControllerTest {
private TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class); private TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private Account senderPinAccount = mock(Account.class); private Account senderPinAccount = mock(Account.class);
private RecaptchaClient recaptchaClient = mock(RecaptchaClient.class); private RecaptchaClient recaptchaClient = mock(RecaptchaClient.class);
private GCMSender gcmSender = mock(GCMSender.class);
private APNSender apnSender = mock(APNSender.class);
@Rule @Rule
public final ResourceTestRule resources = ResourceTestRule.builder() public final ResourceTestRule resources = ResourceTestRule.builder()
@ -97,7 +106,9 @@ public class AccountControllerTest {
storedMessages, storedMessages,
turnTokenGenerator, turnTokenGenerator,
new HashMap<>(), new HashMap<>(),
recaptchaClient)) recaptchaClient,
gcmSender,
apnSender))
.build(); .build();
@ -116,15 +127,17 @@ public class AccountControllerTest {
when(senderPinAccount.getPin()).thenReturn(Optional.of("31337")); when(senderPinAccount.getPin()).thenReturn(Optional.of("31337"));
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis()); when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis()))); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OLD)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31)))); when(pendingAccountsManager.getCodeForNumber(SENDER_OLD)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("333333", System.currentTimeMillis()))); when(pendingAccountsManager.getCodeForNumber(SENDER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("333333", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("444444", System.currentTimeMillis()))); when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("444444", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_PREAUTH)).thenReturn(Optional.of(new StoredVerificationCode("555555", System.currentTimeMillis(), "validchallenge")));
when(accountsManager.get(eq(SENDER_PIN))).thenReturn(Optional.of(senderPinAccount)); when(accountsManager.get(eq(SENDER_PIN))).thenReturn(Optional.of(senderPinAccount));
when(accountsManager.get(eq(SENDER_OVER_PIN))).thenReturn(Optional.of(senderPinAccount)); when(accountsManager.get(eq(SENDER_OVER_PIN))).thenReturn(Optional.of(senderPinAccount));
when(accountsManager.get(eq(SENDER))).thenReturn(Optional.empty()); when(accountsManager.get(eq(SENDER))).thenReturn(Optional.empty());
when(accountsManager.get(eq(SENDER_OLD))).thenReturn(Optional.empty()); when(accountsManager.get(eq(SENDER_OLD))).thenReturn(Optional.empty());
when(accountsManager.get(eq(SENDER_PREAUTH))).thenReturn(Optional.empty());
when(abusiveHostRules.getAbusiveHostRulesFor(eq(ABUSIVE_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(ABUSIVE_HOST, true, Collections.emptyList()))); when(abusiveHostRules.getAbusiveHostRulesFor(eq(ABUSIVE_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(ABUSIVE_HOST, true, Collections.emptyList())));
when(abusiveHostRules.getAbusiveHostRulesFor(eq(RESTRICTED_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(RESTRICTED_HOST, false, Collections.singletonList("+123")))); when(abusiveHostRules.getAbusiveHostRulesFor(eq(RESTRICTED_HOST))).thenReturn(Collections.singletonList(new AbusiveHostRule(RESTRICTED_HOST, false, Collections.singletonList("+123"))));
@ -143,6 +156,45 @@ public class AccountControllerTest {
doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
} }
@Test
public void testGetFcmPreauth() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/accounts/fcm/preauth/mytoken/+14152222222")
.request()
.get();
assertThat(response.getStatus()).isEqualTo(200);
ArgumentCaptor<GcmMessage> captor = ArgumentCaptor.forClass(GcmMessage.class);
verify(gcmSender, times(1)).sendMessage(captor.capture());
assertThat(captor.getValue().getGcmId()).isEqualTo("mytoken");
assertThat(captor.getValue().getData().isPresent()).isTrue();
assertThat(captor.getValue().getData().get().length()).isEqualTo(32);
verifyNoMoreInteractions(apnSender);
}
@Test
public void testGetApnPreauth() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/accounts/apn/preauth/mytoken/+14152222222")
.request()
.get();
assertThat(response.getStatus()).isEqualTo(200);
ArgumentCaptor<ApnMessage> captor = ArgumentCaptor.forClass(ApnMessage.class);
verify(apnSender, times(1)).sendMessage(captor.capture());
assertThat(captor.getValue().getApnId()).isEqualTo("mytoken");
assertThat(captor.getValue().getChallengeData().isPresent()).isTrue();
assertThat(captor.getValue().getChallengeData().get().length()).isEqualTo(32);
assertThat(captor.getValue().getMessage()).contains("\"challenge\" : \"" + captor.getValue().getChallengeData().get() + "\"");
verifyNoMoreInteractions(gcmSender);
}
@Test @Test
public void testSendCode() throws Exception { public void testSendCode() throws Exception {
Response response = Response response =
@ -158,6 +210,54 @@ public class AccountControllerTest {
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(NICE_HOST)); verify(abusiveHostRules).getAbusiveHostRulesFor(eq(NICE_HOST));
} }
@Test
public void testSendCodeWithValidPreauth() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER_PREAUTH))
.queryParam("challenge", "validchallenge")
.request()
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(smsSender).deliverSmsVerification(eq(SENDER_PREAUTH), eq(Optional.empty()), anyString());
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(NICE_HOST));
}
@Test
public void testSendCodeWithInvalidPreauth() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER_PREAUTH))
.queryParam("challenge", "invalidchallenge")
.request()
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(402);
verifyNoMoreInteractions(smsSender);
verifyNoMoreInteractions(abusiveHostRules);
}
@Test
public void testSendCodeWithNoPreauth() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", SENDER_PREAUTH))
.request()
.header("X-Forwarded-For", NICE_HOST)
.get();
assertThat(response.getStatus()).isEqualTo(200);
verify(smsSender).deliverSmsVerification(eq(SENDER_PREAUTH), eq(Optional.empty()), anyString());
verify(abusiveHostRules).getAbusiveHostRulesFor(eq(NICE_HOST));
}
@Test @Test
public void testSendiOSCode() throws Exception { public void testSendiOSCode() throws Exception {
Response response = Response response =

View File

@ -117,8 +117,8 @@ public class DeviceControllerTest {
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice)); when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice));
when(account.isEnabled()).thenReturn(false); when(account.isEnabled()).thenReturn(false);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis()))); when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(new StoredVerificationCode("1112223", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31)))); when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(new StoredVerificationCode("1112223", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31), null)));
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account)); when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount)); when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
} }

View File

@ -63,7 +63,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response)); .thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -75,7 +75,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip"); assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip");
@ -97,7 +97,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response)); .thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -109,7 +109,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo"); assertThat(notification.getValue().getTopic()).isEqualTo("foo");
@ -133,7 +133,7 @@ public class APNSenderTest {
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -150,7 +150,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER);
@ -236,7 +236,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response)); .thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -253,7 +253,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER);
@ -331,7 +331,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response)); .thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -343,7 +343,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.GENERIC_FAILURE); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.GENERIC_FAILURE);
@ -364,7 +364,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), new Exception("lost connection"))); .thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(executor, invocationOnMock.getArgument(0), new Exception("lost connection")));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -384,7 +384,7 @@ public class APNSenderTest {
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
verifyNoMoreInteractions(apnsClient); verifyNoMoreInteractions(apnsClient);

View File

@ -32,7 +32,7 @@ public class GCMSenderTest {
when(successResult.hasCanonicalRegistrationId()).thenReturn(false); when(successResult.hasCanonicalRegistrationId()).thenReturn(false);
when(successResult.isSuccess()).thenReturn(true); when(successResult.isSuccess()).thenReturn(true);
GcmMessage message = new GcmMessage("foo", "+12223334444", 1, false); GcmMessage message = new GcmMessage("foo", "+12223334444", 1, GcmMessage.Type.NOTIFICATION, Optional.empty());
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService); GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
CompletableFuture<Result> successFuture = CompletableFuture.completedFuture(successResult); CompletableFuture<Result> successFuture = CompletableFuture.completedFuture(successResult);
@ -66,7 +66,7 @@ public class GCMSenderTest {
when(invalidResult.hasCanonicalRegistrationId()).thenReturn(false); when(invalidResult.hasCanonicalRegistrationId()).thenReturn(false);
when(invalidResult.isSuccess()).thenReturn(true); when(invalidResult.isSuccess()).thenReturn(true);
GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, false); GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, GcmMessage.Type.NOTIFICATION, Optional.empty());
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService); GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
CompletableFuture<Result> invalidFuture = CompletableFuture.completedFuture(invalidResult); CompletableFuture<Result> invalidFuture = CompletableFuture.completedFuture(invalidResult);
@ -105,7 +105,7 @@ public class GCMSenderTest {
when(canonicalResult.isSuccess()).thenReturn(false); when(canonicalResult.isSuccess()).thenReturn(false);
when(canonicalResult.getCanonicalRegistrationId()).thenReturn(canonicalId); when(canonicalResult.getCanonicalRegistrationId()).thenReturn(canonicalId);
GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, false); GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, GcmMessage.Type.NOTIFICATION, Optional.empty());
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService); GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
CompletableFuture<Result> invalidFuture = CompletableFuture.completedFuture(canonicalResult); CompletableFuture<Result> invalidFuture = CompletableFuture.completedFuture(canonicalResult);

View File

@ -33,7 +33,7 @@ public class PendingAccountsTest {
@Test @Test
public void testStore() throws SQLException { public void testStore() throws SQLException {
pendingAccounts.insert("+14151112222", "1234", 1111); pendingAccounts.insert("+14151112222", "1234", 1111, null);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_accounts WHERE number = ?"); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_accounts WHERE number = ?");
statement.setString(1, "+14151112222"); statement.setString(1, "+14151112222");
@ -43,6 +43,27 @@ public class PendingAccountsTest {
if (resultSet.next()) { if (resultSet.next()) {
assertThat(resultSet.getString("verification_code")).isEqualTo("1234"); assertThat(resultSet.getString("verification_code")).isEqualTo("1234");
assertThat(resultSet.getLong("timestamp")).isEqualTo(1111); assertThat(resultSet.getLong("timestamp")).isEqualTo(1111);
assertThat(resultSet.getString("push_code")).isNull();
} else {
throw new AssertionError("no results");
}
assertThat(resultSet.next()).isFalse();
}
@Test
public void testStoreWithPushChallenge() throws SQLException {
pendingAccounts.insert("+14151112222", null, 1111, "112233");
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM pending_accounts WHERE number = ?");
statement.setString(1, "+14151112222");
ResultSet resultSet = statement.executeQuery();
if (resultSet.next()) {
assertThat(resultSet.getString("verification_code")).isNull();
assertThat(resultSet.getLong("timestamp")).isEqualTo(1111);
assertThat(resultSet.getString("push_code")).isEqualTo("112233");
} else { } else {
throw new AssertionError("no results"); throw new AssertionError("no results");
} }
@ -52,8 +73,8 @@ public class PendingAccountsTest {
@Test @Test
public void testRetrieve() throws Exception { public void testRetrieve() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222); pendingAccounts.insert("+14151112222", "4321", 2222, null);
pendingAccounts.insert("+14151113333", "1212", 5555); pendingAccounts.insert("+14151113333", "1212", 5555, null);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
@ -65,10 +86,26 @@ public class PendingAccountsTest {
assertThat(missingCode.isPresent()).isFalse(); assertThat(missingCode.isPresent()).isFalse();
} }
@Test
public void testRetrieveWithPushChallenge() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222, "bar");
pendingAccounts.insert("+14151113333", "1212", 5555, "bang");
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4321");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(2222);
assertThat(verificationCode.get().getPushCode()).isEqualTo("bar");
Optional<StoredVerificationCode> missingCode = pendingAccounts.getCodeForNumber("+11111111111");
assertThat(missingCode.isPresent()).isFalse();
}
@Test @Test
public void testOverwrite() throws Exception { public void testOverwrite() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222); pendingAccounts.insert("+14151112222", "4321", 2222, null);
pendingAccounts.insert("+14151112222", "4444", 3333); pendingAccounts.insert("+14151112222", "4444", 3333, null);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
@ -77,10 +114,24 @@ public class PendingAccountsTest {
assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333); assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333);
} }
@Test
public void testOverwriteWithPushToken() throws Exception {
pendingAccounts.insert("+14151112222", "4321", 2222, "bar");
pendingAccounts.insert("+14151112222", "4444", 3333, "bang");
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("4444");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(3333);
assertThat(verificationCode.get().getPushCode()).isEqualTo("bang");
}
@Test @Test
public void testVacuum() { public void testVacuum() {
pendingAccounts.insert("+14151112222", "4321", 2222); pendingAccounts.insert("+14151112222", "4321", 2222, null);
pendingAccounts.insert("+14151112222", "4444", 3333); pendingAccounts.insert("+14151112222", "4444", 3333, null);
pendingAccounts.vacuum(); pendingAccounts.vacuum();
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
@ -92,8 +143,8 @@ public class PendingAccountsTest {
@Test @Test
public void testRemove() { public void testRemove() {
pendingAccounts.insert("+14151112222", "4321", 2222); pendingAccounts.insert("+14151112222", "4321", 2222, "bar");
pendingAccounts.insert("+14151113333", "1212", 5555); pendingAccounts.insert("+14151113333", "1212", 5555, null);
Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222"); Optional<StoredVerificationCode> verificationCode = pendingAccounts.getCodeForNumber("+14151112222");
@ -110,6 +161,7 @@ public class PendingAccountsTest {
assertThat(verificationCode.isPresent()).isTrue(); assertThat(verificationCode.isPresent()).isTrue();
assertThat(verificationCode.get().getCode()).isEqualTo("1212"); assertThat(verificationCode.get().getCode()).isEqualTo("1212");
assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555); assertThat(verificationCode.get().getTimestamp()).isEqualTo(5555);
assertThat(verificationCode.get().getPushCode()).isNull();
} }