Support for UUID based addressing

This commit is contained in:
Moxie Marlinspike 2019-06-20 19:25:15 -07:00
parent 0f8cb7ea6d
commit 7a3a385569
51 changed files with 1379 additions and 695 deletions

View File

@ -155,6 +155,14 @@
<version>0.13.1</version> <version>0.13.1</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.fasterxml.uuid</groupId>
<artifactId>java-uuid-generator</artifactId>
<version>3.2.0</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>

View File

@ -31,6 +31,7 @@ message Envelope {
optional Type type = 1; optional Type type = 1;
optional string source = 2; optional string source = 2;
optional string sourceUuid = 11;
optional uint32 sourceDevice = 7; optional uint32 sourceDevice = 7;
optional string relay = 3; optional string relay = 3;
optional uint64 timestamp = 5; optional uint64 timestamp = 5;
@ -57,6 +58,7 @@ message ServerCertificate {
message SenderCertificate { message SenderCertificate {
message Certificate { message Certificate {
optional string sender = 1; optional string sender = 1;
optional string senderUuid = 6;
optional uint32 senderDevice = 2; optional uint32 senderDevice = 2;
optional fixed64 expires = 3; optional fixed64 expires = 3;
optional bytes identityKey = 4; optional bytes identityKey = 4;

View File

@ -45,7 +45,6 @@ import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ProfileController; import org.whispersystems.textsecuregcm.controllers.ProfileController;
import org.whispersystems.textsecuregcm.controllers.ProvisioningController; import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.TransparentDataController;
import org.whispersystems.textsecuregcm.controllers.StickerController; import org.whispersystems.textsecuregcm.controllers.StickerController;
import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -261,7 +260,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender)); environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));
environment.jersey().register(new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()))); environment.jersey().register(new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays())));
environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales())); environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales()));
environment.jersey().register(new TransparentDataController(accountsManager, config.getTransparentDataIndex()));
environment.jersey().register(new SecureStorageController(storageCredentialsGenerator)); environment.jersey().register(new SecureStorageController(storageCredentialsGenerator));
environment.jersey().register(attachmentControllerV1); environment.jersey().register(attachmentControllerV1);
environment.jersey().register(attachmentControllerV2); environment.jersey().register(attachmentControllerV2);

View File

@ -0,0 +1,35 @@
package org.whispersystems.textsecuregcm.auth;
import java.util.UUID;
public class AmbiguousIdentifier {
private final UUID uuid;
private final String number;
public AmbiguousIdentifier(String target) {
if (target.startsWith("+")) {
this.uuid = null;
this.number = target;
} else {
this.uuid = UUID.fromString(target);
this.number = null;
}
}
public UUID getUuid() {
return uuid;
}
public String getNumber() {
return number;
}
public boolean hasUuid() {
return uuid != null;
}
public boolean hasNumber() {
return number != null;
}
}

View File

@ -24,20 +24,20 @@ import java.io.IOException;
public class AuthorizationHeader { public class AuthorizationHeader {
private final String number; private final AmbiguousIdentifier identifier;
private final long accountId; private final long deviceId;
private final String password; private final String password;
private AuthorizationHeader(String number, long accountId, String password) { private AuthorizationHeader(AmbiguousIdentifier identifier, long deviceId, String password) {
this.number = number; this.identifier = identifier;
this.accountId = accountId; this.deviceId = deviceId;
this.password = password; this.password = password;
} }
public static AuthorizationHeader fromUserAndPassword(String user, String password) throws InvalidAuthorizationHeaderException { public static AuthorizationHeader fromUserAndPassword(String user, String password) throws InvalidAuthorizationHeaderException {
try { try {
String[] numberAndId = user.split("\\."); String[] numberAndId = user.split("\\.");
return new AuthorizationHeader(numberAndId[0], return new AuthorizationHeader(new AmbiguousIdentifier(numberAndId[0]),
numberAndId.length > 1 ? Long.parseLong(numberAndId[1]) : 1, numberAndId.length > 1 ? Long.parseLong(numberAndId[1]) : 1,
password); password);
} catch (NumberFormatException nfe) { } catch (NumberFormatException nfe) {
@ -79,12 +79,12 @@ public class AuthorizationHeader {
} }
} }
public String getNumber() { public AmbiguousIdentifier getIdentifier() {
return number; return identifier;
} }
public long getDeviceId() { public long getDeviceId() {
return accountId; return deviceId;
} }
public String getPassword() { public String getPassword() {

View File

@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
@ -38,7 +39,7 @@ public class BaseAccountAuthenticator {
public Optional<Account> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) { public Optional<Account> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) {
try { try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword()); AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getNumber()); Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier());
if (!account.isPresent()) { if (!account.isPresent()) {
noSuchAccountMeter.mark(); noSuchAccountMeter.mark();
@ -73,7 +74,7 @@ public class BaseAccountAuthenticator {
authenticationFailedMeter.mark(); authenticationFailedMeter.mark();
return Optional.empty(); return Optional.empty();
} catch (InvalidAuthorizationHeaderException iahe) { } catch (IllegalArgumentException | InvalidAuthorizationHeaderException iae) {
invalidAuthHeaderMeter.mark(); invalidAuthHeaderMeter.mark();
return Optional.empty(); return Optional.empty();
} }

View File

@ -31,6 +31,7 @@ public class CertificateGenerator {
public byte[] createFor(Account account, Device device) throws IOException, InvalidKeyException { public byte[] createFor(Account account, Device device) throws IOException, InvalidKeyException {
byte[] certificate = SenderCertificate.Certificate.newBuilder() byte[] certificate = SenderCertificate.Certificate.newBuilder()
.setSender(account.getNumber()) .setSender(account.getNumber())
.setSenderUuid(account.getUuid().toString())
.setSenderDevice(Math.toIntExact(device.getId())) .setSenderDevice(Math.toIntExact(device.getId()))
.setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays)) .setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays))
.setIdentityKey(ByteString.copyFrom(Base64.decode(account.getIdentityKey()))) .setIdentityKey(ByteString.copyFrom(Base64.decode(account.getIdentityKey())))

View File

@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.auth.TurnToken; import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.AccountCreationResult;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceName; import org.whispersystems.textsecuregcm.entities.DeviceName;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@ -78,6 +79,7 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -245,7 +247,7 @@ public class AccountController {
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/code/{verification_code}") @Path("/code/{verification_code}")
public void verifyAccount(@PathParam("verification_code") String verificationCode, public AccountCreationResult verifyAccount(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader, @HeaderParam("Authorization") String authorizationHeader,
@HeaderParam("X-Signal-Agent") String userAgent, @HeaderParam("X-Signal-Agent") String userAgent,
@Valid AccountAttributes accountAttributes) @Valid AccountAttributes accountAttributes)
@ -253,9 +255,13 @@ public class AccountController {
{ {
try { try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader); AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber(); String number = header.getIdentifier().getNumber();
String password = header.getPassword(); String password = header.getPassword();
if (number == null) {
throw new WebApplicationException(400);
}
rateLimiters.getVerifyLimiter().validate(number); rateLimiters.getVerifyLimiter().validate(number);
Optional<StoredVerificationCode> storedVerificationCode = pendingAccounts.getCodeForNumber(number); Optional<StoredVerificationCode> storedVerificationCode = pendingAccounts.getCodeForNumber(number);
@ -308,9 +314,11 @@ public class AccountController {
rateLimiters.getPinLimiter().clear(number); rateLimiters.getPinLimiter().clear(number);
} }
createAccount(number, password, userAgent, accountAttributes); Account account = createAccount(number, password, userAgent, accountAttributes);
metricRegistry.meter(name(AccountController.class, "verify", Util.getCountryCode(number))).mark(); metricRegistry.meter(name(AccountController.class, "verify", Util.getCountryCode(number))).mark();
return new AccountCreationResult(account.getUuid());
} catch (InvalidAuthorizationHeaderException e) { } catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad Authorization Header", e); logger.info("Bad Authorization Header", e);
throw new WebApplicationException(Response.status(401).build()); throw new WebApplicationException(Response.status(401).build());
@ -502,6 +510,13 @@ public class AccountController {
accounts.update(account); accounts.update(account);
} }
@GET
@Path("/whoami")
@Produces(MediaType.APPLICATION_JSON)
public AccountCreationResult whoAmI(@Auth Account account) {
return new AccountCreationResult(account.getUuid());
}
private CaptchaRequirement requiresCaptcha(String number, String transport, String forwardedFor, private CaptchaRequirement requiresCaptcha(String number, String transport, String forwardedFor,
String requester, String requester,
Optional<String> captchaToken, Optional<String> captchaToken,
@ -576,7 +591,7 @@ public class AccountController {
return false; return false;
} }
private void createAccount(String number, String password, String userAgent, AccountAttributes accountAttributes) { private Account createAccount(String number, String password, String userAgent, AccountAttributes accountAttributes) {
Device device = new Device(); Device device = new Device();
device.setId(Device.MASTER_ID); device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setAuthenticationCredentials(new AuthenticationCredentials(password));
@ -591,6 +606,7 @@ public class AccountController {
Account account = new Account(); Account account = new Account();
account.setNumber(number); account.setNumber(number);
account.setUuid(UUID.randomUUID());
account.addDevice(device); account.addDevice(device);
account.setPin(accountAttributes.getPin()); account.setPin(accountAttributes.getPin());
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey()); account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
@ -608,6 +624,8 @@ public class AccountController {
messagesManager.clear(number); messagesManager.clear(number);
pendingAccounts.remove(number); pendingAccounts.remove(number);
return account;
} }
@VisibleForTesting protected @VisibleForTesting protected

View File

@ -164,9 +164,11 @@ public class DeviceController {
{ {
try { try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader); AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber(); String number = header.getIdentifier().getNumber();
String password = header.getPassword(); String password = header.getPassword();
if (number == null) throw new WebApplicationException(400);
rateLimiters.getVerifyDeviceLimiter().validate(number); rateLimiters.getVerifyDeviceLimiter().validate(number);
Optional<StoredVerificationCode> storedVerificationCode = pendingDevices.getCodeForNumber(number); Optional<StoredVerificationCode> storedVerificationCode = pendingDevices.getCodeForNumber(number);

View File

@ -19,6 +19,7 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
@ -115,11 +116,11 @@ public class KeysController {
@Timed @Timed
@GET @GET
@Path("/{number}/{device_id}") @Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponse> getDeviceKeys(@Auth Optional<Account> account, public Optional<PreKeyResponse> getDeviceKeys(@Auth Optional<Account> account,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("number") String number, @PathParam("identifier") AmbiguousIdentifier targetName,
@PathParam("device_id") String deviceId) @PathParam("device_id") String deviceId)
throws RateLimitExceededException throws RateLimitExceededException
{ {
@ -127,13 +128,13 @@ public class KeysController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
Optional<Account> target = accounts.get(number); Optional<Account> target = accounts.get(targetName);
OptionalAccess.verify(account, accessKey, target, deviceId); OptionalAccess.verify(account, accessKey, target, deviceId);
assert(target.isPresent()); assert(target.isPresent());
if (account.isPresent()) { if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + number + "." + deviceId); rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "__" + target.get().getNumber() + "." + deviceId);
} }
List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId); List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId);

View File

@ -23,6 +23,7 @@ import com.codahale.metrics.annotation.Timed;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@ -109,7 +110,7 @@ public class MessageController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public SendMessageResponse sendMessage(@Auth Optional<Account> source, public SendMessageResponse sendMessage(@Auth Optional<Account> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("destination") String destinationName, @PathParam("destination") AmbiguousIdentifier destinationName,
@Valid IncomingMessageList messages) @Valid IncomingMessageList messages)
throws RateLimitExceededException throws RateLimitExceededException
{ {
@ -117,18 +118,18 @@ public class MessageController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
if (source.isPresent() && !source.get().getNumber().equals(destinationName)) { if (source.isPresent() && !source.get().isFor(destinationName)) {
rateLimiters.getMessagesLimiter().validate(source.get().getNumber() + "__" + destinationName); rateLimiters.getMessagesLimiter().validate(source.get().getNumber() + "__" + destinationName);
} }
if (source.isPresent() && !source.get().getNumber().equals(destinationName)) { if (source.isPresent() && !source.get().isFor(destinationName)) {
identifiedMeter.mark(); identifiedMeter.mark();
} else { } else if (!source.isPresent()) {
unidentifiedMeter.mark(); unidentifiedMeter.mark();
} }
try { try {
boolean isSyncMessage = source.isPresent() && source.get().getNumber().equals(destinationName); boolean isSyncMessage = source.isPresent() && source.get().isFor(destinationName);
Optional<Account> destination; Optional<Account> destination;
@ -246,6 +247,7 @@ public class MessageController {
if (source.isPresent()) { if (source.isPresent()) {
messageBuilder.setSource(source.get().getNumber()) messageBuilder.setSource(source.get().getNumber())
.setSourceUuid(source.get().getUuid().toString())
.setSourceDevice((int)source.get().getAuthenticatedDevice().get().getId()); .setSourceDevice((int)source.get().getAuthenticatedDevice().get().getId());
} }

View File

@ -10,6 +10,7 @@ import com.codahale.metrics.annotation.Timed;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.hibernate.validator.constraints.Length; import org.hibernate.validator.constraints.Length;
import org.hibernate.validator.valuehandling.UnwrapValidatedValue; import org.hibernate.validator.valuehandling.UnwrapValidatedValue;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
@ -79,10 +80,10 @@ public class ProfileController {
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{number}") @Path("/{identifier}")
public Profile getProfile(@Auth Optional<Account> requestAccount, public Profile getProfile(@Auth Optional<Account> requestAccount,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("number") String number, @PathParam("identifier") AmbiguousIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate) @QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException throws RateLimitExceededException
{ {
@ -94,7 +95,7 @@ public class ProfileController {
rateLimiters.getProfileLimiter().validate(requestAccount.get().getNumber()); rateLimiters.getProfileLimiter().validate(requestAccount.get().getNumber());
} }
Optional<Account> accountProfile = accountsManager.get(number); Optional<Account> accountProfile = accountsManager.get(identifier);
OptionalAccess.verify(requestAccount, accessKey, accountProfile); OptionalAccess.verify(requestAccount, accessKey, accountProfile);
//noinspection ConstantConditions,OptionalGetWithoutIsPresent //noinspection ConstantConditions,OptionalGetWithoutIsPresent

View File

@ -1,42 +0,0 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PublicAccount;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.util.Map;
import java.util.Optional;
@Path("/v1/transparency/")
public class TransparentDataController {
private final AccountsManager accountsManager;
private final Map<String, String> transparentDataIndex;
public TransparentDataController(AccountsManager accountsManager,
Map<String, String> transparentDataIndex)
{
this.accountsManager = accountsManager;
this.transparentDataIndex = transparentDataIndex;
}
@GET
@Path("/account/{id}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PublicAccount> getAccount(@PathParam("id") String id) {
String index = transparentDataIndex.get(id);
if (index != null) {
return accountsManager.get(index).map(PublicAccount::new);
}
return Optional.empty();
}
}

View File

@ -0,0 +1,21 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.UUID;
public class AccountCreationResult {
@JsonProperty
private UUID uuid;
public AccountCreationResult() {}
public AccountCreationResult(UUID uuid) {
this.uuid = uuid;
}
public UUID getUuid() {
return uuid;
}
}

View File

@ -19,10 +19,11 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Map; import java.util.Map;
import java.util.HashMap; import java.util.HashMap;
import java.util.UUID;
public class ActiveUserTally { public class ActiveUserTally {
@JsonProperty @JsonProperty
private String fromNumber; private UUID fromUuid;
@JsonProperty @JsonProperty
private Map<String, long[]> platforms; private Map<String, long[]> platforms;
@ -32,14 +33,14 @@ public class ActiveUserTally {
public ActiveUserTally() {} public ActiveUserTally() {}
public ActiveUserTally(String fromNumber, Map<String, long[]> platforms, Map<String, long[]> countries) { public ActiveUserTally(UUID fromUuid, Map<String, long[]> platforms, Map<String, long[]> countries) {
this.fromNumber = fromNumber; this.fromUuid = fromUuid;
this.platforms = platforms; this.platforms = platforms;
this.countries = countries; this.countries = countries;
} }
public String getFromNumber() { public UUID getFromUuid() {
return this.fromNumber; return this.fromUuid;
} }
public Map<String, long[]> getPlatforms() { public Map<String, long[]> getPlatforms() {
@ -50,8 +51,8 @@ public class ActiveUserTally {
return this.countries; return this.countries;
} }
public void setFromNumber(String fromNumber) { public void setFromUuid(UUID fromUuid) {
this.fromNumber = fromNumber; this.fromUuid = fromUuid;
} }
} }

View File

@ -19,14 +19,15 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List; import java.util.List;
import java.util.UUID;
public class DirectoryReconciliationRequest { public class DirectoryReconciliationRequest {
@JsonProperty @JsonProperty
private String fromNumber; private UUID fromUuid;
@JsonProperty @JsonProperty
private String toNumber; private UUID toUuid;
@JsonProperty @JsonProperty
private List<String> numbers; private List<String> numbers;
@ -34,18 +35,18 @@ public class DirectoryReconciliationRequest {
public DirectoryReconciliationRequest() { public DirectoryReconciliationRequest() {
} }
public DirectoryReconciliationRequest(String fromNumber, String toNumber, List<String> numbers) { public DirectoryReconciliationRequest(UUID fromUuid, UUID toUuid, List<String> numbers) {
this.fromNumber = fromNumber; this.fromUuid = fromUuid;
this.toNumber = toNumber; this.toUuid = toUuid;
this.numbers = numbers; this.numbers = numbers;
} }
public String getFromNumber() { public UUID getFromUuid() {
return fromNumber; return fromUuid;
} }
public String getToNumber() { public UUID getToUuid() {
return toNumber; return toUuid;
} }
public List<String> getNumbers() { public List<String> getNumbers() {

View File

@ -28,6 +28,9 @@ public class OutgoingMessageEntity {
@JsonProperty @JsonProperty
private String source; private String source;
@JsonProperty
private UUID sourceUuid;
@JsonProperty @JsonProperty
private int sourceDevice; private int sourceDevice;
@ -44,8 +47,8 @@ public class OutgoingMessageEntity {
public OutgoingMessageEntity(long id, boolean cached, public OutgoingMessageEntity(long id, boolean cached,
UUID guid, int type, String relay, long timestamp, UUID guid, int type, String relay, long timestamp,
String source, int sourceDevice, byte[] message, String source, UUID sourceUuid, int sourceDevice,
byte[] content, long serverTimestamp) byte[] message, byte[] content, long serverTimestamp)
{ {
this.id = id; this.id = id;
this.cached = cached; this.cached = cached;
@ -54,6 +57,7 @@ public class OutgoingMessageEntity {
this.relay = relay; this.relay = relay;
this.timestamp = timestamp; this.timestamp = timestamp;
this.source = source; this.source = source;
this.sourceUuid = sourceUuid;
this.sourceDevice = sourceDevice; this.sourceDevice = sourceDevice;
this.message = message; this.message = message;
this.content = content; this.content = content;
@ -80,6 +84,10 @@ public class OutgoingMessageEntity {
return source; return source;
} }
public UUID getSourceUuid() {
return sourceUuid;
}
public int getSourceDevice() { public int getSourceDevice() {
return sourceDevice; return sourceDevice;
} }

View File

@ -21,11 +21,14 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import java.security.Principal; import java.security.Principal;
import java.util.HashSet; import java.util.HashSet;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class Account implements Principal { public class Account implements Principal {
@ -33,6 +36,9 @@ public class Account implements Principal {
static final int MEMCACHE_VERION = 5; static final int MEMCACHE_VERION = 5;
@JsonIgnore @JsonIgnore
private UUID uuid;
@JsonProperty
private String number; private String number;
@JsonProperty @JsonProperty
@ -71,8 +77,9 @@ public class Account implements Principal {
public Account() {} public Account() {}
@VisibleForTesting @VisibleForTesting
public Account(String number, Set<Device> devices, byte[] unidentifiedAccessKey) { public Account(String number, UUID uuid, Set<Device> devices, byte[] unidentifiedAccessKey) {
this.number = number; this.number = number;
this.uuid = uuid;
this.devices = devices; this.devices = devices;
this.unidentifiedAccessKey = unidentifiedAccessKey; this.unidentifiedAccessKey = unidentifiedAccessKey;
} }
@ -85,6 +92,14 @@ public class Account implements Principal {
this.authenticatedDevice = device; this.authenticatedDevice = device;
} }
public UUID getUuid() {
return uuid;
}
public void setUuid(UUID uuid) {
this.uuid = uuid;
}
public void setNumber(String number) { public void setNumber(String number) {
this.number = number; this.number = number;
} }
@ -247,6 +262,12 @@ public class Account implements Principal {
this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess; this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess;
} }
public boolean isFor(AmbiguousIdentifier identifier) {
if (identifier.hasUuid()) return identifier.getUuid().equals(uuid);
else if (identifier.hasNumber()) return identifier.getNumber().equals(number);
else throw new AssertionError();
}
// Principal implementation // Principal implementation
@Override @Override

View File

@ -26,6 +26,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -51,7 +52,7 @@ public class AccountCleaner implements AccountDatabaseCrawlerListener {
} }
@Override @Override
public void onCrawlChunk(Optional<String> fromNumber, List<Account> chunkAccounts) { public void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) {
int accountUpdateCount = 0; int accountUpdateCount = 0;
for (Account account : chunkAccounts) { for (Account account : chunkAccounts) {
if (needsExplicitRemoval(account)) { if (needsExplicitRemoval(account)) {
@ -74,7 +75,7 @@ public class AccountCleaner implements AccountDatabaseCrawlerListener {
} }
@Override @Override
public void onCrawlEnd(Optional<String> fromNumber) { public void onCrawlEnd(Optional<UUID> fromUuid) {
} }
private boolean needsExplicitRemoval(Account account) { private boolean needsExplicitRemoval(Account account) {

View File

@ -33,6 +33,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AccountDatabaseCrawler implements Managed, Runnable { public class AccountDatabaseCrawler implements Managed, Runnable {
private static final Logger logger = LoggerFactory.getLogger(AccountDatabaseCrawler.class); private static final Logger logger = LoggerFactory.getLogger(AccountDatabaseCrawler.class);
@ -91,6 +92,7 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
sleepWhileRunning(accelerated ? ACCELERATED_CHUNK_INTERVAL : chunkIntervalMs); sleepWhileRunning(accelerated ? ACCELERATED_CHUNK_INTERVAL : chunkIntervalMs);
} catch (Throwable t) { } catch (Throwable t) {
logger.warn("error in database crawl: ", t); logger.warn("error in database crawl: ", t);
Util.sleep(10000);
} }
} }
@ -120,26 +122,26 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
} }
private void processChunk() { private void processChunk() {
Optional<String> fromNumber = cache.getLastNumber(); Optional<UUID> fromUuid = cache.getLastUuid();
if (!fromNumber.isPresent()) { if (!fromUuid.isPresent()) {
listeners.forEach(listener -> { listener.onCrawlStart(); }); listeners.forEach(AccountDatabaseCrawlerListener::onCrawlStart);
} }
List<Account> chunkAccounts = readChunk(fromNumber, chunkSize); List<Account> chunkAccounts = readChunk(fromUuid, chunkSize);
if (chunkAccounts.isEmpty()) { if (chunkAccounts.isEmpty()) {
listeners.forEach(listener -> { listener.onCrawlEnd(fromNumber); }); listeners.forEach(listener -> listener.onCrawlEnd(fromUuid));
cache.setLastNumber(Optional.empty()); cache.setLastUuid(Optional.empty());
cache.clearAccelerate(); cache.clearAccelerate();
} else { } else {
try { try {
for (AccountDatabaseCrawlerListener listener : listeners) { for (AccountDatabaseCrawlerListener listener : listeners) {
listener.onCrawlChunk(fromNumber, chunkAccounts); listener.onCrawlChunk(fromUuid, chunkAccounts);
} }
cache.setLastNumber(Optional.of(chunkAccounts.get(chunkAccounts.size() - 1).getNumber())); cache.setLastUuid(Optional.of(chunkAccounts.get(chunkAccounts.size() - 1).getUuid()));
} catch (AccountDatabaseCrawlerRestartException e) { } catch (AccountDatabaseCrawlerRestartException e) {
cache.setLastNumber(Optional.empty()); cache.setLastUuid(Optional.empty());
cache.clearAccelerate(); cache.clearAccelerate();
} }
@ -147,12 +149,12 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
} }
private List<Account> readChunk(Optional<String> fromNumber, int chunkSize) { private List<Account> readChunk(Optional<UUID> fromUuid, int chunkSize) {
try (Timer.Context timer = readChunkTimer.time()) { try (Timer.Context timer = readChunkTimer.time()) {
List<Account> chunkAccounts; List<Account> chunkAccounts;
if (fromNumber.isPresent()) { if (fromUuid.isPresent()) {
chunkAccounts = accounts.getAllFrom(fromNumber.get(), chunkSize); chunkAccounts = accounts.getAllFrom(fromUuid.get(), chunkSize);
} else { } else {
chunkAccounts = accounts.getAllFrom(chunkSize); chunkAccounts = accounts.getAllFrom(chunkSize);
} }

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2018 Open WhisperSystems * Copyright (C) 2018 Open WhisperSystems
* <p> * <p>
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -18,17 +18,20 @@ package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.redis.LuaScript; import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import redis.clients.jedis.Jedis;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import redis.clients.jedis.Jedis;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AccountDatabaseCrawlerCache { public class AccountDatabaseCrawlerCache {
private static final String ACTIVE_WORKER_KEY = "account_database_crawler_cache_active_worker"; private static final String ACTIVE_WORKER_KEY = "account_database_crawler_cache_active_worker";
private static final String LAST_NUMBER_KEY = "account_database_crawler_cache_last_number"; private static final String LAST_UUID_KEY = "account_database_crawler_cache_last_uuid";
private static final String ACCELERATE_KEY = "account_database_crawler_cache_accelerate"; private static final String ACCELERATE_KEY = "account_database_crawler_cache_accelerate";
private static final long LAST_NUMBER_TTL_MS = 86400_000L; private static final long LAST_NUMBER_TTL_MS = 86400_000L;
@ -65,18 +68,21 @@ public class AccountDatabaseCrawlerCache {
luaScript.execute(keys, args); luaScript.execute(keys, args);
} }
public Optional<String> getLastNumber() { public Optional<UUID> getLastUuid() {
try (Jedis jedis = jedisPool.getWriteResource()) { try (Jedis jedis = jedisPool.getWriteResource()) {
return Optional.ofNullable(jedis.get(LAST_NUMBER_KEY)); String lastUuidString = jedis.get(LAST_UUID_KEY);
if (lastUuidString == null) return Optional.empty();
else return Optional.of(UUID.fromString(lastUuidString));
} }
} }
public void setLastNumber(Optional<String> lastNumber) { public void setLastUuid(Optional<UUID> lastUuid) {
try (Jedis jedis = jedisPool.getWriteResource()) { try (Jedis jedis = jedisPool.getWriteResource()) {
if (lastNumber.isPresent()) { if (lastUuid.isPresent()) {
jedis.psetex(LAST_NUMBER_KEY, LAST_NUMBER_TTL_MS, lastNumber.get()); jedis.psetex(LAST_UUID_KEY, LAST_NUMBER_TTL_MS, lastUuid.get().toString());
} else { } else {
jedis.del(LAST_NUMBER_KEY); jedis.del(LAST_UUID_KEY);
} }
} }
} }

View File

@ -18,9 +18,11 @@ package org.whispersystems.textsecuregcm.storage;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public interface AccountDatabaseCrawlerListener { public interface AccountDatabaseCrawlerListener {
void onCrawlStart(); void onCrawlStart();
void onCrawlChunk(Optional<String> fromNumber, List<Account> chunkAccounts) throws AccountDatabaseCrawlerRestartException; void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) throws AccountDatabaseCrawlerRestartException;
void onCrawlEnd(Optional<String> fromNumber); void onCrawlEnd(Optional<UUID> fromUuid);
} }

View File

@ -28,12 +28,14 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
public class Accounts { public class Accounts {
public static final String ID = "id"; public static final String ID = "id";
public static final String UID = "uuid";
public static final String NUMBER = "number"; public static final String NUMBER = "number";
public static final String DATA = "data"; public static final String DATA = "data";
@ -42,7 +44,8 @@ public class Accounts {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer createTimer = metricRegistry.timer(name(Accounts.class, "create" )); private final Timer createTimer = metricRegistry.timer(name(Accounts.class, "create" ));
private final Timer updateTimer = metricRegistry.timer(name(Accounts.class, "update" )); private final Timer updateTimer = metricRegistry.timer(name(Accounts.class, "update" ));
private final Timer getTimer = metricRegistry.timer(name(Accounts.class, "get")); private final Timer getByNumberTimer = metricRegistry.timer(name(Accounts.class, "getByNumber" ));
private final Timer getByUuidTimer = metricRegistry.timer(name(Accounts.class, "getByUuid" ));
private final Timer getAllFromTimer = metricRegistry.timer(name(Accounts.class, "getAllFrom" )); private final Timer getAllFromTimer = metricRegistry.timer(name(Accounts.class, "getAllFrom" ));
private final Timer getAllFromOffsetTimer = metricRegistry.timer(name(Accounts.class, "getAllFromOffset")); private final Timer getAllFromOffsetTimer = metricRegistry.timer(name(Accounts.class, "getAllFromOffset"));
private final Timer vacuumTimer = metricRegistry.timer(name(Accounts.class, "vacuum" )); private final Timer vacuumTimer = metricRegistry.timer(name(Accounts.class, "vacuum" ));
@ -57,16 +60,15 @@ public class Accounts {
public boolean create(Account account) { public boolean create(Account account) {
return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
try (Timer.Context ignored = createTimer.time()) { try (Timer.Context ignored = createTimer.time()) {
int rows = handle.createUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number") UUID uuid = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET data = EXCLUDED.data RETURNING uuid")
.bind("number", account.getNumber())
.execute();
handle.createUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
.bind("number", account.getNumber()) .bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account)) .bind("data", mapper.writeValueAsString(account))
.execute(); .mapTo(UUID.class)
.findOnly();
return rows == 0; account.setUuid(uuid);
return uuid.equals(account.getUuid());
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
@ -76,8 +78,8 @@ public class Accounts {
public void update(Account account) { public void update(Account account) {
database.use(jdbi -> jdbi.useHandle(handle -> { database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = updateTimer.time()) { try (Timer.Context ignored = updateTimer.time()) {
handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number") handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + UID + " = :uuid")
.bind("number", account.getNumber()) .bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account)) .bind("data", mapper.writeValueAsString(account))
.execute(); .execute();
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
@ -88,7 +90,7 @@ public class Accounts {
public Optional<Account> get(String number) { public Optional<Account> get(String number) {
return database.with(jdbi -> jdbi.withHandle(handle -> { return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getTimer.time()) { try (Timer.Context ignored = getByNumberTimer.time()) {
return handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number") return handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
.bind("number", number) .bind("number", number)
.mapTo(Account.class) .mapTo(Account.class)
@ -97,10 +99,21 @@ public class Accounts {
})); }));
} }
public List<Account> getAllFrom(String from, int length) { public Optional<Account> get(UUID uuid) {
return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getByUuidTimer.time()) {
return handle.createQuery("SELECT * FROM accounts WHERE " + UID + " = :uuid")
.bind("uuid", uuid)
.mapTo(Account.class)
.findFirst();
}
}));
}
public List<Account> getAllFrom(UUID from, int length) {
return database.with(jdbi -> jdbi.withHandle(handle -> { return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getAllFromOffsetTimer.time()) { try (Timer.Context ignored = getAllFromOffsetTimer.time()) {
return handle.createQuery("SELECT * FROM accounts WHERE " + NUMBER + " > :from ORDER BY " + NUMBER + " LIMIT :limit") return handle.createQuery("SELECT * FROM accounts WHERE " + UID + " > :from ORDER BY " + UID + " LIMIT :limit")
.bind("from", from) .bind("from", from)
.bind("limit", length) .bind("limit", length)
.mapTo(Account.class) .mapTo(Account.class)
@ -112,7 +125,7 @@ public class Accounts {
public List<Account> getAllFrom(int length) { public List<Account> getAllFrom(int length) {
return database.with(jdbi -> jdbi.withHandle(handle -> { return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getAllFromTimer.time()) { try (Timer.Context ignored = getAllFromTimer.time()) {
return handle.createQuery("SELECT * FROM accounts ORDER BY " + NUMBER + " LIMIT :limit") return handle.createQuery("SELECT * FROM accounts ORDER BY " + UID + " LIMIT :limit")
.bind("limit", length) .bind("limit", length)
.mapTo(Account.class) .mapTo(Account.class)
.list(); .list();

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
@ -33,6 +34,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import redis.clients.jedis.Jedis; import redis.clients.jedis.Jedis;
@ -43,10 +45,12 @@ public class AccountsManager {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer createTimer = metricRegistry.timer(name(AccountsManager.class, "create" )); private static final Timer createTimer = metricRegistry.timer(name(AccountsManager.class, "create" ));
private static final Timer updateTimer = metricRegistry.timer(name(AccountsManager.class, "update" )); private static final Timer updateTimer = metricRegistry.timer(name(AccountsManager.class, "update" ));
private static final Timer getTimer = metricRegistry.timer(name(AccountsManager.class, "get" )); private static final Timer getByNumberTimer = metricRegistry.timer(name(AccountsManager.class, "getByNumber"));
private static final Timer getByUuidTimer = metricRegistry.timer(name(AccountsManager.class, "getByUuid" ));
private static final Timer redisSetTimer = metricRegistry.timer(name(AccountsManager.class, "redisSet" )); private static final Timer redisSetTimer = metricRegistry.timer(name(AccountsManager.class, "redisSet" ));
private static final Timer redisGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisGet" )); private static final Timer redisNumberGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisNumberGet"));
private static final Timer redisUuidGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUuidGet" ));
private final Logger logger = LoggerFactory.getLogger(AccountsManager.class); private final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
@ -65,7 +69,7 @@ public class AccountsManager {
public boolean create(Account account) { public boolean create(Account account) {
try (Timer.Context context = createTimer.time()) { try (Timer.Context context = createTimer.time()) {
boolean freshUser = databaseCreate(account); boolean freshUser = databaseCreate(account);
redisSet(account.getNumber(), account, false); redisSet(account);
updateDirectory(account); updateDirectory(account);
return freshUser; return freshUser;
@ -74,31 +78,51 @@ public class AccountsManager {
public void update(Account account) { public void update(Account account) {
try (Timer.Context context = updateTimer.time()) { try (Timer.Context context = updateTimer.time()) {
redisSet(account.getNumber(), account, false); redisSet(account);
databaseUpdate(account); databaseUpdate(account);
updateDirectory(account); updateDirectory(account);
} }
} }
public Optional<Account> get(AmbiguousIdentifier identifier) {
if (identifier.hasNumber()) return get(identifier.getNumber());
else if (identifier.hasUuid()) return get(identifier.getUuid());
else throw new AssertionError();
}
public Optional<Account> get(String number) { public Optional<Account> get(String number) {
try (Timer.Context context = getTimer.time()) { try (Timer.Context context = getByNumberTimer.time()) {
Optional<Account> account = redisGet(number); Optional<Account> account = redisGet(number);
if (!account.isPresent()) { if (!account.isPresent()) {
account = databaseGet(number); account = databaseGet(number);
account.ifPresent(value -> redisSet(number, value, true)); account.ifPresent(value -> redisSet(value));
} }
return account; return account;
} }
} }
public Optional<Account> get(UUID uuid) {
try (Timer.Context context = getByUuidTimer.time()) {
Optional<Account> account = redisGet(uuid);
if (!account.isPresent()) {
account = databaseGet(uuid);
account.ifPresent(value -> redisSet(value));
}
return account;
}
}
public List<Account> getAllFrom(int length) { public List<Account> getAllFrom(int length) {
return accounts.getAllFrom(length); return accounts.getAllFrom(length);
} }
public List<Account> getAllFrom(String number, int length) { public List<Account> getAllFrom(UUID uuid, int length) {
return accounts.getAllFrom(number, length); return accounts.getAllFrom(uuid, length);
} }
private void updateDirectory(Account account) { private void updateDirectory(Account account) {
@ -111,15 +135,20 @@ public class AccountsManager {
} }
} }
private String getKey(String number) { private String getAccountMapKey(String number) {
return Account.class.getSimpleName() + Account.MEMCACHE_VERION + number; return "AccountMap::" + number;
} }
private void redisSet(String number, Account account, boolean optional) { private String getAccountEntityKey(UUID uuid) {
return "Account::" + uuid.toString();
}
private void redisSet(Account account) {
try (Jedis jedis = cacheClient.getWriteResource(); try (Jedis jedis = cacheClient.getWriteResource();
Timer.Context timer = redisSetTimer.time()) Timer.Context timer = redisSetTimer.time())
{ {
jedis.set(getKey(number), mapper.writeValueAsString(account)); jedis.set(getAccountMapKey(account.getNumber()), account.getUuid().toString());
jedis.set(getAccountEntityKey(account.getUuid()), mapper.writeValueAsString(account));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalStateException(e); throw new IllegalStateException(e);
} }
@ -127,20 +156,14 @@ public class AccountsManager {
private Optional<Account> redisGet(String number) { private Optional<Account> redisGet(String number) {
try (Jedis jedis = cacheClient.getReadResource(); try (Jedis jedis = cacheClient.getReadResource();
Timer.Context timer = redisGetTimer.time()) Timer.Context timer = redisNumberGetTimer.time())
{ {
String json = jedis.get(getKey(number)); String uuid = jedis.get(getAccountMapKey(number));
if (json != null) { if (uuid != null) return redisGet(UUID.fromString(uuid));
Account account = mapper.readValue(json, Account.class); else return Optional.empty();
account.setNumber(number); } catch (IllegalArgumentException e) {
logger.warn("Deserialization error", e);
return Optional.of(account);
}
return Optional.empty();
} catch (IOException e) {
logger.warn("AccountsManager", "Deserialization error", e);
return Optional.empty(); return Optional.empty();
} catch (JedisException e) { } catch (JedisException e) {
logger.warn("Redis failure", e); logger.warn("Redis failure", e);
@ -148,10 +171,38 @@ public class AccountsManager {
} }
} }
private Optional<Account> redisGet(UUID uuid) {
try (Jedis jedis = cacheClient.getReadResource();
Timer.Context timer = redisUuidGetTimer.time())
{
String json = jedis.get(getAccountEntityKey(uuid));
if (json != null) {
Account account = mapper.readValue(json, Account.class);
account.setUuid(uuid);
return Optional.of(account);
}
return Optional.empty();
} catch (IOException e) {
logger.warn("Deserialization error", e);
return Optional.empty();
} catch (JedisException e) {
logger.warn("Redis failure", e);
return Optional.empty();
}
}
private Optional<Account> databaseGet(String number) { private Optional<Account> databaseGet(String number) {
return accounts.get(number); return accounts.get(number);
} }
private Optional<Account> databaseGet(UUID uuid) {
return accounts.get(uuid);
}
private boolean databaseCreate(Account account) { private boolean databaseCreate(Account account) {
return accounts.create(account); return accounts.create(account);
} }

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2018 Open WhisperSystems * Copyright (C) 2018 Open WhisperSystems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -20,22 +20,22 @@ import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import io.dropwizard.metrics.MetricsFactory;
import io.dropwizard.metrics.ReporterFactory;
import org.whispersystems.textsecuregcm.entities.ActiveUserTally; import org.whispersystems.textsecuregcm.entities.ActiveUserTally;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import redis.clients.jedis.Jedis;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern; import io.dropwizard.metrics.MetricsFactory;
import io.dropwizard.metrics.ReporterFactory;
import redis.clients.jedis.Jedis;
public class ActiveUserCounter implements AccountDatabaseCrawlerListener { public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
@ -44,8 +44,6 @@ public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
private static final String PLATFORM_IOS = "ios"; private static final String PLATFORM_IOS = "ios";
private static final String PLATFORM_ANDROID = "android"; private static final String PLATFORM_ANDROID = "android";
private static final String FIRST_FROM_NUMBER = "+";
private static final String INTERVALS[] = {"daily", "weekly", "monthly", "quarterly", "yearly"}; private static final String INTERVALS[] = {"daily", "weekly", "monthly", "quarterly", "yearly"};
private final MetricsFactory metricsFactory; private final MetricsFactory metricsFactory;
@ -64,7 +62,8 @@ public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
} }
} }
public void onCrawlChunk(Optional<String> fromNumber, List<Account> chunkAccounts) { @Override
public void onCrawlChunk(Optional<UUID> fromNumber, List<Account> chunkAccounts) {
long nowDays = TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis()); long nowDays = TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis());
long agoMs[] = {TimeUnit.DAYS.toMillis(nowDays - 1), long agoMs[] = {TimeUnit.DAYS.toMillis(nowDays - 1),
TimeUnit.DAYS.toMillis(nowDays - 7), TimeUnit.DAYS.toMillis(nowDays - 7),
@ -107,23 +106,21 @@ public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
} }
} }
incrementTallies(fromNumber.orElse(FIRST_FROM_NUMBER), platformIncrements, countryIncrements); incrementTallies(fromNumber.orElse(UUID.randomUUID()), platformIncrements, countryIncrements);
} }
public void onCrawlEnd(Optional<String> fromNumber) { @Override
public void onCrawlEnd(Optional<UUID> fromNumber) {
MetricRegistry metrics = new MetricRegistry(); MetricRegistry metrics = new MetricRegistry();
long intervalTallies[] = new long[INTERVALS.length]; long intervalTallies[] = new long[INTERVALS.length];
ActiveUserTally activeUserTally = getFinalTallies(); ActiveUserTally activeUserTally = getFinalTallies();
Map<String, long[]> platforms = activeUserTally.getPlatforms(); Map<String, long[]> platforms = activeUserTally.getPlatforms();
platforms.forEach((platform, platformTallies) -> { platforms.forEach((platform, platformTallies) -> {
for (int i = 0; i < INTERVALS.length; i++) { for (int i = 0; i < INTERVALS.length; i++) {
final long tally = platformTallies[i]; final long tally = platformTallies[i];
metrics.register(metricKey(platform, INTERVALS[i]), metrics.register(metricKey(platform, INTERVALS[i]),
new Gauge<Long>() { (Gauge<Long>) () -> tally);
@Override
public Long getValue() { return tally; }
});
intervalTallies[i] += tally; intervalTallies[i] += tally;
} }
}); });
@ -133,21 +130,16 @@ public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
for (int i = 0; i < INTERVALS.length; i++) { for (int i = 0; i < INTERVALS.length; i++) {
final long tally = countryTallies[i]; final long tally = countryTallies[i];
metrics.register(metricKey(country, INTERVALS[i]), metrics.register(metricKey(country, INTERVALS[i]),
new Gauge<Long>() { (Gauge<Long>) () -> tally);
@Override
public Long getValue() { return tally; }
});
} }
}); });
for (int i = 0; i < INTERVALS.length; i++) { for (int i = 0; i < INTERVALS.length; i++) {
final long intervalTotal = intervalTallies[i]; final long intervalTotal = intervalTallies[i];
metrics.register(metricKey(INTERVALS[i]), metrics.register(metricKey(INTERVALS[i]),
new Gauge<Long>() { (Gauge<Long>) () -> intervalTotal);
@Override
public Long getValue() { return intervalTotal; }
});
} }
for (ReporterFactory reporterFactory : metricsFactory.getReporters()) { for (ReporterFactory reporterFactory : metricsFactory.getReporters()) {
reporterFactory.build(metrics).report(); reporterFactory.build(metrics).report();
} }
@ -162,22 +154,25 @@ public class ActiveUserCounter implements AccountDatabaseCrawlerListener {
return tally; return tally;
} }
private void incrementTallies(String fromNumber, Map<String, long[]> platformIncrements, Map<String, long[]> countryIncrements) { private void incrementTallies(UUID fromUuid, Map<String, long[]> platformIncrements, Map<String, long[]> countryIncrements) {
try (Jedis jedis = jedisPool.getWriteResource()) { try (Jedis jedis = jedisPool.getWriteResource()) {
String tallyValue = jedis.get(TALLY_KEY); String tallyValue = jedis.get(TALLY_KEY);
ActiveUserTally activeUserTally; ActiveUserTally activeUserTally;
if (tallyValue == null) { if (tallyValue == null) {
activeUserTally = new ActiveUserTally(fromNumber, platformIncrements, countryIncrements); activeUserTally = new ActiveUserTally(fromUuid, platformIncrements, countryIncrements);
} else { } else {
activeUserTally = mapper.readValue(tallyValue, ActiveUserTally.class); activeUserTally = mapper.readValue(tallyValue, ActiveUserTally.class);
if (activeUserTally.getFromNumber() != fromNumber) {
activeUserTally.setFromNumber(fromNumber); if (!fromUuid.equals(activeUserTally.getFromUuid())) {
activeUserTally.setFromUuid(fromUuid);
Map<String, long[]> platformTallies = activeUserTally.getPlatforms(); Map<String, long[]> platformTallies = activeUserTally.getPlatforms();
addTallyMaps(platformTallies, platformIncrements); addTallyMaps(platformTallies, platformIncrements);
Map<String, long[]> countryTallies = activeUserTally.getCountries(); Map<String, long[]> countryTallies = activeUserTally.getCountries();
addTallyMaps(countryTallies, countryIncrements); addTallyMaps(countryTallies, countryIncrements);
} }
} }
jedis.set(TALLY_KEY, mapper.writeValueAsString(activeUserTally)); jedis.set(TALLY_KEY, mapper.writeValueAsString(activeUserTally));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);

View File

@ -33,6 +33,7 @@ import javax.ws.rs.ProcessingException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -55,18 +56,16 @@ public class DirectoryReconciler implements AccountDatabaseCrawlerListener {
public void onCrawlStart() { } public void onCrawlStart() { }
public void onCrawlEnd(Optional<String> fromNumber) { public void onCrawlEnd(Optional<UUID> fromUuid) {
DirectoryReconciliationRequest request = new DirectoryReconciliationRequest(fromUuid.orElse(null), null, Collections.emptyList());
DirectoryReconciliationRequest request = new DirectoryReconciliationRequest(fromNumber.orElse(null), null, Collections.emptyList());
DirectoryReconciliationResponse response = sendChunk(request); DirectoryReconciliationResponse response = sendChunk(request);
} }
public void onCrawlChunk(Optional<String> fromNumber, List<Account> chunkAccounts) throws AccountDatabaseCrawlerRestartException { public void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) throws AccountDatabaseCrawlerRestartException {
updateDirectoryCache(chunkAccounts); updateDirectoryCache(chunkAccounts);
DirectoryReconciliationRequest request = createChunkRequest(fromNumber, chunkAccounts); DirectoryReconciliationRequest request = createChunkRequest(fromUuid, chunkAccounts);
DirectoryReconciliationResponse response = sendChunk(request); DirectoryReconciliationResponse response = sendChunk(request);
if (response.getStatus() == DirectoryReconciliationResponse.Status.MISSING) { if (response.getStatus() == DirectoryReconciliationResponse.Status.MISSING) {
throw new AccountDatabaseCrawlerRestartException("directory reconciler missing"); throw new AccountDatabaseCrawlerRestartException("directory reconciler missing");
@ -93,19 +92,19 @@ public class DirectoryReconciler implements AccountDatabaseCrawlerListener {
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private DirectoryReconciliationRequest createChunkRequest(Optional<String> fromNumber, List<Account> accounts) { private DirectoryReconciliationRequest createChunkRequest(Optional<UUID> fromUuid, List<Account> accounts) {
List<String> numbers = accounts.stream() List<String> numbers = accounts.stream()
.filter(Account::isEnabled) .filter(Account::isEnabled)
.map(Account::getNumber) .map(Account::getNumber)
.collect(Collectors.toList()); .collect(Collectors.toList());
Optional<String> toNumber = Optional.empty(); Optional<UUID> toUuid = Optional.empty();
if (!accounts.isEmpty()) { if (!accounts.isEmpty()) {
toNumber = Optional.of(accounts.get(accounts.size() - 1).getNumber()); toUuid = Optional.of(accounts.get(accounts.size() - 1).getUuid());
} }
return new DirectoryReconciliationRequest(fromNumber.orElse(null), toNumber.orElse(null), numbers); return new DirectoryReconciliationRequest(fromUuid.orElse(null), toUuid.orElse(null), numbers);
} }
private DirectoryReconciliationResponse sendChunk(DirectoryReconciliationRequest request) { private DirectoryReconciliationResponse sendChunk(DirectoryReconciliationRequest request) {

View File

@ -25,6 +25,7 @@ public class Messages {
public static final String TIMESTAMP = "timestamp"; public static final String TIMESTAMP = "timestamp";
public static final String SERVER_TIMESTAMP = "server_timestamp"; public static final String SERVER_TIMESTAMP = "server_timestamp";
public static final String SOURCE = "source"; public static final String SOURCE = "source";
public static final String SOURCE_UUID = "source_uuid";
public static final String SOURCE_DEVICE = "source_device"; public static final String SOURCE_DEVICE = "source_device";
public static final String DESTINATION = "destination"; public static final String DESTINATION = "destination";
public static final String DESTINATION_DEVICE = "destination_device"; public static final String DESTINATION_DEVICE = "destination_device";
@ -51,8 +52,8 @@ public class Messages {
public void store(UUID guid, Envelope message, String destination, long destinationDevice) { public void store(UUID guid, Envelope message, String destination, long destinationDevice) {
database.use(jdbi ->jdbi.useHandle(handle -> { database.use(jdbi ->jdbi.useHandle(handle -> {
try (Timer.Context ignored = storeTimer.time()) { try (Timer.Context ignored = storeTimer.time()) {
handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + handle.createUpdate("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
"VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_device, :destination, :destination_device, :message, :content)") "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)")
.bind("guid", guid) .bind("guid", guid)
.bind("destination", destination) .bind("destination", destination)
.bind("destination_device", destinationDevice) .bind("destination_device", destinationDevice)
@ -61,6 +62,7 @@ public class Messages {
.bind("timestamp", message.getTimestamp()) .bind("timestamp", message.getTimestamp())
.bind("server_timestamp", message.getServerTimestamp()) .bind("server_timestamp", message.getServerTimestamp())
.bind("source", message.hasSource() ? message.getSource() : null) .bind("source", message.hasSource() ? message.getSource() : null)
.bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null)
.bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null)
.bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null)
.bind("content", message.hasContent() ? message.getContent().toByteArray() : null) .bind("content", message.hasContent() ? message.getContent().toByteArray() : null)

View File

@ -203,6 +203,7 @@ public class MessagesCache implements Managed {
envelope.getRelay(), envelope.getRelay(),
envelope.getTimestamp(), envelope.getTimestamp(),
envelope.getSource(), envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(), envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null, envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null, envelope.hasContent() ? envelope.getContent().toByteArray() : null,

View File

@ -1,18 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
public class PublicAccount extends Account {
public PublicAccount() {}
public PublicAccount(Account account) {
setIdentityKey(account.getIdentityKey());
setUnidentifiedAccessKey(account.getUnidentifiedAccessKey().orElse(null));
setUnrestrictedUnidentifiedAccess(account.isUnrestrictedUnidentifiedAccess());
setAvatar(account.getAvatar());
setProfileName(account.getProfileName());
setPin("******");
account.getDevices().forEach(this::addDevice);
}
}

View File

@ -9,6 +9,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -31,7 +32,7 @@ public class PushFeedbackProcessor implements AccountDatabaseCrawlerListener {
public void onCrawlStart() {} public void onCrawlStart() {}
@Override @Override
public void onCrawlChunk(Optional<String> fromNumber, List<Account> chunkAccounts) { public void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) {
for (Account account : chunkAccounts) { for (Account account : chunkAccounts) {
boolean update = false; boolean update = false;
@ -65,5 +66,5 @@ public class PushFeedbackProcessor implements AccountDatabaseCrawlerListener {
} }
@Override @Override
public void onCrawlEnd(Optional<String> fromNumber) {} public void onCrawlEnd(Optional<UUID> toUuid) {}
} }

View File

@ -10,6 +10,7 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException; import java.io.IOException;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.UUID;
public class AccountRowMapper implements RowMapper<Account> { public class AccountRowMapper implements RowMapper<Account> {
@ -20,6 +21,7 @@ public class AccountRowMapper implements RowMapper<Account> {
try { try {
Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class); Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class);
account.setNumber(resultSet.getString(Accounts.NUMBER)); account.setNumber(resultSet.getString(Accounts.NUMBER));
account.setUuid(UUID.fromString(resultSet.getString(Accounts.UID)));
return account; return account;
} catch (IOException e) { } catch (IOException e) {
throw new SQLException(e); throw new SQLException(e);

View File

@ -16,6 +16,7 @@ public class OutgoingMessageEntityRowMapper implements RowMapper<OutgoingMessage
int type = resultSet.getInt(Messages.TYPE); int type = resultSet.getInt(Messages.TYPE);
byte[] legacyMessage = resultSet.getBytes(Messages.MESSAGE); byte[] legacyMessage = resultSet.getBytes(Messages.MESSAGE);
String guid = resultSet.getString(Messages.GUID); String guid = resultSet.getString(Messages.GUID);
String sourceUuid = resultSet.getString(Messages.SOURCE_UUID);
if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) { if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) {
/// XXX - REMOVE AFTER 10/01/15 /// XXX - REMOVE AFTER 10/01/15
@ -29,6 +30,7 @@ public class OutgoingMessageEntityRowMapper implements RowMapper<OutgoingMessage
resultSet.getString(Messages.RELAY), resultSet.getString(Messages.RELAY),
resultSet.getLong(Messages.TIMESTAMP), resultSet.getLong(Messages.TIMESTAMP),
resultSet.getString(Messages.SOURCE), resultSet.getString(Messages.SOURCE),
sourceUuid == null ? null : UUID.fromString(sourceUuid),
resultSet.getInt(Messages.SOURCE_DEVICE), resultSet.getInt(Messages.SOURCE_DEVICE),
legacyMessage, legacyMessage,
resultSet.getBytes(Messages.CONTENT), resultSet.getBytes(Messages.CONTENT),

View File

@ -197,4 +197,14 @@
<dropNotNullConstraint tableName="pending_accounts" columnName="verification_code"/> <dropNotNullConstraint tableName="pending_accounts" columnName="verification_code"/>
</changeSet> </changeSet>
<changeSet id="7" author="moxie">
<addColumn tableName="accounts">
<column name="uuid" type="uuid"/>
</addColumn>
</changeSet>
<changeSet id="8" author="moxie" runInTransaction="false">
<sql>CREATE UNIQUE INDEX CONCURRENTLY uuid_index ON accounts (uuid);</sql>
</changeSet>
</databaseChangeLog> </databaseChangeLog>

View File

@ -121,4 +121,10 @@
<sql>CREATE INDEX CONCURRENTLY guid_index ON messages (guid);</sql> <sql>CREATE INDEX CONCURRENTLY guid_index ON messages (guid);</sql>
</changeSet> </changeSet>
<changeSet id="13" author="moxie">
<addColumn tableName="messages">
<column name="source_uuid" type="uuid"/>
</addColumn>
</changeSet>
</databaseChangeLog> </databaseChangeLog>

View File

@ -14,9 +14,10 @@ import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.AccountCreationResult;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeprecatedPin; import org.whispersystems.textsecuregcm.entities.DeprecatedPin;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
@ -463,15 +464,15 @@ public class AccountControllerTest {
@Test @Test
public void testVerifyCode() throws Exception { public void testVerifyCode() throws Exception {
Response response = AccountCreationResult result =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234")) .target(String.format("/v1/accounts/code/%s", "1234"))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes("keykeykeykey", false, 2222, null), .put(Entity.entity(new AccountAttributes("keykeykeykey", false, 2222, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(response.getStatus()).isEqualTo(204); assertThat(result.getUuid()).isNotNull();
verify(accountsManager, times(1)).create(isA(Account.class)); verify(accountsManager, times(1)).create(isA(Account.class));
verify(directoryQueue, times(1)).deleteRegisteredUser(eq(SENDER)); verify(directoryQueue, times(1)).deleteRegisteredUser(eq(SENDER));
@ -509,30 +510,30 @@ public class AccountControllerTest {
@Test @Test
public void testVerifyPin() throws Exception { public void testVerifyPin() throws Exception {
Response response = AccountCreationResult result =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "333333")) .target(String.format("/v1/accounts/code/%s", "333333"))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, "31337"), .put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, "31337"),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(response.getStatus()).isEqualTo(204); assertThat(result.getUuid()).isNotNull();
verify(pinLimiter).validate(eq(SENDER_PIN)); verify(pinLimiter).validate(eq(SENDER_PIN));
} }
@Test @Test
public void testVerifyRegistrationLock() throws Exception { public void testVerifyRegistrationLock() throws Exception {
Response response = AccountCreationResult result =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666")) .target(String.format("/v1/accounts/code/%s", "666666"))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, null, null, Hex.toStringCondensed(registration_lock_key)), .put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, null, null, Hex.toStringCondensed(registration_lock_key)),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(response.getStatus()).isEqualTo(204); assertThat(result.getUuid()).isNotNull();
verify(pinLimiter).validate(eq(SENDER_REG_LOCK)); verify(pinLimiter).validate(eq(SENDER_REG_LOCK));
} }
@ -628,15 +629,15 @@ public class AccountControllerTest {
try { try {
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7)); when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7));
Response response = AccountCreationResult result =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "444444")) .target(String.format("/v1/accounts/code/%s", "444444"))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, null), .put(Entity.entity(new AccountAttributes("keykeykeykey", false, 3333, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(response.getStatus()).isEqualTo(204); assertThat(result.getUuid()).isNotNull();
} finally { } finally {
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis()); when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
@ -666,7 +667,7 @@ public class AccountControllerTest {
resources.getJerseyTest() resources.getJerseyTest()
.target("/v1/accounts/registration_lock/") .target("/v1/accounts/registration_lock/")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234"))); .put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234")));
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
@ -745,7 +746,6 @@ public class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
} }
@Test @Test
public void testSetGcmId() throws Exception { public void testSetGcmId() throws Exception {
Response response = Response response =
@ -761,6 +761,21 @@ public class AccountControllerTest {
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
} }
@Test
public void testSetGcmIdByUuid() throws Exception {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/gcm/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID.toString(), AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new GcmRegistrationId("z000")));
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
}
@Test @Test
public void testSetApnId() throws Exception { public void testSetApnId() throws Exception {
Response response = Response response =
@ -777,5 +792,32 @@ public class AccountControllerTest {
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
} }
@Test
public void testSetApnIdByUuid() throws Exception {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/apn/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID.toString(), AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new ApnRegistrationId("third", "fourth")));
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("third"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("fourth"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
}
@Test
public void testWhoAmI() {
AccountCreationResult response =
resources.getJerseyTest()
.target("/v1/accounts/whoami/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(AccountCreationResult.class);
assertThat(response.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
}
} }

View File

@ -75,6 +75,7 @@ public class CertificateControllerTest {
assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER); assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER);
assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertTrue(Arrays.equals(certificate.getIdentityKey().toByteArray(), Base64.decode(AuthHelper.VALID_IDENTITY))); assertTrue(Arrays.equals(certificate.getIdentityKey().toByteArray(), Base64.decode(AuthHelper.VALID_IDENTITY)));
} }

View File

@ -6,6 +6,8 @@ import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.KeysController;
@ -32,6 +34,7 @@ import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule; import io.dropwizard.testing.junit.ResourceTestRule;
@ -41,7 +44,10 @@ import static org.mockito.Mockito.*;
public class KeyControllerTest { public class KeyControllerTest {
private static final String EXISTS_NUMBER = "+14152222222"; private static final String EXISTS_NUMBER = "+14152222222";
private static final UUID EXISTS_UUID = UUID.randomUUID();
private static String NOT_EXISTS_NUMBER = "+14152222220"; private static String NOT_EXISTS_NUMBER = "+14152222220";
private static UUID NOT_EXISTS_UUID = UUID.randomUUID();
private static int SAMPLE_REGISTRATION_ID = 999; private static int SAMPLE_REGISTRATION_ID = 999;
private static int SAMPLE_REGISTRATION_ID2 = 1002; private static int SAMPLE_REGISTRATION_ID2 = 1002;
@ -117,7 +123,14 @@ public class KeyControllerTest {
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes())); when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.get(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount)); when(accounts.get(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount));
when(accounts.get(EXISTS_UUID)).thenReturn(Optional.of(existsAccount));
when(accounts.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(EXISTS_NUMBER)))).thenReturn(Optional.of(existsAccount));
when(accounts.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(EXISTS_UUID)))).thenReturn(Optional.of(existsAccount));
when(accounts.get(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>empty()); when(accounts.get(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>empty());
when(accounts.get(NOT_EXISTS_UUID)).thenReturn(Optional.empty());
when(accounts.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(NOT_EXISTS_NUMBER)))).thenReturn(Optional.empty());
when(accounts.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(NOT_EXISTS_UUID)))).thenReturn(Optional.empty());
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
@ -141,7 +154,7 @@ public class KeyControllerTest {
} }
@Test @Test
public void validKeyStatusTestV2() throws Exception { public void validKeyStatusTestByNumberV2() throws Exception {
PreKeyCount result = resources.getJerseyTest() PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys") .target("/v2/keys")
.request() .request()
@ -155,7 +168,22 @@ public class KeyControllerTest {
} }
@Test @Test
public void getSignedPreKeyV2() throws Exception { public void validKeyStatusTestByUuidV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization",
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class);
assertThat(result.getCount()).isEqualTo(4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
}
@Test
public void getSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey result = resources.getJerseyTest() SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.request() .request()
@ -168,7 +196,20 @@ public class KeyControllerTest {
} }
@Test @Test
public void putSignedPreKeyV2() throws Exception { public void getSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
}
@Test
public void putSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz"); SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
@ -183,7 +224,23 @@ public class KeyControllerTest {
} }
@Test @Test
public void disabledPutSignedPreKeyV2() throws Exception { public void putSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9998, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
}
@Test
public void disabledPutSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz"); SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
@ -195,7 +252,20 @@ public class KeyControllerTest {
} }
@Test @Test
public void validSingleRequestTestV2() throws Exception { public void disabledPutSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID.toString(), AuthHelper.DISABLED_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401);
}
@Test
public void validSingleRequestTestV2ByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER)) .target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request() .request()
@ -213,7 +283,26 @@ public class KeyControllerTest {
} }
@Test @Test
public void testUnidentifiedRequest() throws Exception { public void validSingleRequestTestV2ByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys);
}
@Test
public void testUnidentifiedRequestByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER)) .target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request() .request()
@ -230,6 +319,25 @@ public class KeyControllerTest {
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
} }
@Test
public void testUnidentifiedRequestByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys);
}
@Test @Test
public void testUnauthorizedUnidentifiedRequest() throws Exception { public void testUnauthorizedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
@ -256,7 +364,7 @@ public class KeyControllerTest {
@Test @Test
public void validMultiRequestTestV2() throws Exception { public void validMultiRequestTestV2ByNumber() throws Exception {
PreKeyResponse results = resources.getJerseyTest() PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_NUMBER)) .target(String.format("/v2/keys/%s/*", EXISTS_NUMBER))
.request() .request()
@ -305,6 +413,57 @@ public class KeyControllerTest {
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
} }
@Test
public void validMultiRequestTestV2ByUuid() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY.getPublicKey());
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY2.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY2.getPublicKey());
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys);
}
@Test @Test
public void invalidRequestTestV2() throws Exception { public void invalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()

View File

@ -7,6 +7,8 @@ import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
@ -55,7 +57,10 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur
public class MessageControllerTest { public class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private final PushSender pushSender = mock(PushSender.class ); private final PushSender pushSender = mock(PushSender.class );
private final ReceiptSender receiptSender = mock(ReceiptSender.class); private final ReceiptSender receiptSender = mock(ReceiptSender.class);
@ -89,11 +94,13 @@ public class MessageControllerTest {
add(new Device(3, null, "foo", "bar", "baz", "isgcm", null, null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31), System.currentTimeMillis(), "Test", true, 0)); add(new Device(3, null, "foo", "bar", "baz", "isgcm", null, null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31), System.currentTimeMillis(), "Test", true, 0));
}}; }};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, singleDeviceList, "1234".getBytes()); Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, multiDeviceList, "1234".getBytes()); Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, multiDeviceList, "1234".getBytes());
when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(SINGLE_DEVICE_RECIPIENT)))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(eq(MULTI_DEVICE_RECIPIENT))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.get(eq(MULTI_DEVICE_RECIPIENT))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(MULTI_DEVICE_RECIPIENT)))).thenReturn(Optional.of(multiDeviceAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
} }
@ -240,11 +247,12 @@ public class MessageControllerTest {
final long timestampOne = 313377; final long timestampOne = 313377;
final long timestampTwo = 313388; final long timestampTwo = 313388;
final UUID uuidOne = UUID.randomUUID(); final UUID messageGuidOne = UUID.randomUUID();
final UUID sourceUuid = UUID.randomUUID();
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{ List<OutgoingMessageEntity> messages = new LinkedList<>() {{
add(new OutgoingMessageEntity(1L, false, uuidOne, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null, 0)); add(new OutgoingMessageEntity(1L, false, messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", sourceUuid, 2, "hi there".getBytes(), null, 0));
add(new OutgoingMessageEntity(2L, false, null, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null, 0)); add(new OutgoingMessageEntity(2L, false, null, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", sourceUuid, 2, null, null, 0));
}}; }};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -254,7 +262,7 @@ public class MessageControllerTest {
OutgoingMessageEntityList response = OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.accept(MediaType.APPLICATION_JSON_TYPE) .accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class); .get(OutgoingMessageEntityList.class);
@ -267,8 +275,11 @@ public class MessageControllerTest {
assertEquals(response.getMessages().get(0).getTimestamp(), timestampOne); assertEquals(response.getMessages().get(0).getTimestamp(), timestampOne);
assertEquals(response.getMessages().get(1).getTimestamp(), timestampTwo); assertEquals(response.getMessages().get(1).getTimestamp(), timestampTwo);
assertEquals(response.getMessages().get(0).getGuid(), uuidOne); assertEquals(response.getMessages().get(0).getGuid(), messageGuidOne);
assertEquals(response.getMessages().get(1).getGuid(), null); assertNull(response.getMessages().get(1).getGuid());
assertEquals(response.getMessages().get(0).getSourceUuid(), sourceUuid);
assertEquals(response.getMessages().get(1).getSourceUuid(), sourceUuid);
} }
@Test @Test
@ -277,8 +288,8 @@ public class MessageControllerTest {
final long timestampTwo = 313388; final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{ List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1L, false, UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null, 0)); add(new OutgoingMessageEntity(1L, false, UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", UUID.randomUUID(), 2, "hi there".getBytes(), null, 0));
add(new OutgoingMessageEntity(2L, false, UUID.randomUUID(), Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null, 0)); add(new OutgoingMessageEntity(2L, false, UUID.randomUUID(), Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", UUID.randomUUID(), 2, null, null, 0));
}}; }};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -288,7 +299,7 @@ public class MessageControllerTest {
Response response = Response response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.INVALID_PASSWORD))
.accept(MediaType.APPLICATION_JSON_TYPE) .accept(MediaType.APPLICATION_JSON_TYPE)
.get(); .get();
@ -299,17 +310,19 @@ public class MessageControllerTest {
public synchronized void testDeleteMessages() throws Exception { public synchronized void testDeleteMessages() throws Exception {
long timestamp = System.currentTimeMillis(); long timestamp = System.currentTimeMillis();
UUID sourceUuid = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337)) when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null,
Envelope.Type.CIPHERTEXT_VALUE, Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, null, timestamp,
"+14152222222", 1, "hi".getBytes(), null, 0))); "+14152222222", sourceUuid, 1, "hi".getBytes(), null, 0)));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338)) when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null,
Envelope.Type.RECEIPT_VALUE, Envelope.Type.RECEIPT_VALUE,
null, System.currentTimeMillis(), null, System.currentTimeMillis(),
"+14152222222", 1, null, null, 0))); "+14152222222", sourceUuid, 1, null, null, 0)));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31339)) when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31339))
@ -327,7 +340,7 @@ public class MessageControllerTest {
response = resources.getJerseyTest() response = resources.getJerseyTest()
.target(String.format("/v1/messages/%s/%d", "+14152222222", 31338)) .target(String.format("/v1/messages/%s/%d", "+14152222222", 31338))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID.toString(), AuthHelper.VALID_PASSWORD))
.delete(); .delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));

View File

@ -5,6 +5,8 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before; import org.junit.Before;
import org.junit.ClassRule; import org.junit.ClassRule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.configuration.CdnConfiguration; import org.whispersystems.textsecuregcm.configuration.CdnConfiguration;
import org.whispersystems.textsecuregcm.controllers.ProfileController; import org.whispersystems.textsecuregcm.controllers.ProfileController;
@ -63,6 +65,7 @@ public class ProfileControllerTest {
when(profileAccount.isEnabled()).thenReturn(true); when(profileAccount.isEnabled()).thenReturn(true);
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(AuthHelper.VALID_NUMBER_TWO)))).thenReturn(Optional.of(profileAccount));
} }
@ -78,7 +81,7 @@ public class ProfileControllerTest {
assertThat(profile.getName()).isEqualTo("baz"); assertThat(profile.getName()).isEqualTo("baz");
assertThat(profile.getAvatar()).isEqualTo("profiles/bang"); assertThat(profile.getAvatar()).isEqualTo("profiles/bang");
verify(accountsManager, times(1)).get(AuthHelper.VALID_NUMBER_TWO); verify(accountsManager, times(1)).get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(AuthHelper.VALID_NUMBER_TWO)));
verify(rateLimiters, times(1)).getProfileLimiter(); verify(rateLimiters, times(1)).getProfileLimiter();
verify(rateLimiter, times(1)).validate(eq(AuthHelper.VALID_NUMBER)); verify(rateLimiter, times(1)).validate(eq(AuthHelper.VALID_NUMBER));
} }

View File

@ -1,129 +0,0 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.collect.ImmutableSet;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.controllers.TransparentDataController;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PublicAccount;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import static junit.framework.TestCase.*;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
public class TransparentDataControllerTest {
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final Map<String, String> indexMap = new HashMap<>();
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new TransparentDataController(accountsManager, indexMap))
.build();
@Before
public void setup() {
Account accountOne = new Account("+14151231111", Collections.singleton(new Device(1, "foo", "bar", "salt", "keykey", "gcm-id", "apn-id", "voipapn-id", true, 1234, new SignedPreKey(5, "public-signed", "signtture-signed"), 31337, 31336, "CoolClient", true, 0)), new byte[16]);
Account accountTwo = new Account("+14151232222", Collections.singleton(new Device(1, "2foo", "2bar", "2salt", "2keykey", "2gcm-id", "2apn-id", "2voipapn-id", true, 1234, new SignedPreKey(5, "public-signed", "signtture-signed"), 31337, 31336, "CoolClient", true, 0)), new byte[16]);
accountOne.setProfileName("OneProfileName");
accountOne.setIdentityKey("identity_key_value");
accountTwo.setProfileName("TwoProfileName");
accountTwo.setIdentityKey("different_identity_key_value");
indexMap.put("1", "+14151231111");
indexMap.put("2", "+14151232222");
when(accountsManager.get(eq("+14151231111"))).thenReturn(Optional.of(accountOne));
when(accountsManager.get(eq("+14151232222"))).thenReturn(Optional.of(accountTwo));
}
@Test
public void testAccountOne() throws IOException {
Response response = resources.getJerseyTest()
.target(String.format("/v1/transparency/account/%s", "1"))
.request()
.get();
assertEquals(200, response.getStatus());
Account result = response.readEntity(PublicAccount.class);
assertTrue(result.getPin().isPresent());
assertEquals("******", result.getPin().get());
assertNull(result.getNumber());
assertEquals("OneProfileName", result.getProfileName());
assertThat("Account serialization works",
asJson(result),
is(equalTo(jsonFixture("fixtures/transparent_account.json"))));
verify(accountsManager, times(1)).get(eq("+14151231111"));
verifyNoMoreInteractions(accountsManager);
}
@Test
public void testAccountTwo() throws IOException {
Response response = resources.getJerseyTest()
.target(String.format("/v1/transparency/account/%s", "2"))
.request()
.get();
assertEquals(200, response.getStatus());
Account result = response.readEntity(PublicAccount.class);
assertTrue(result.getPin().isPresent());
assertEquals("******", result.getPin().get());
assertNull(result.getNumber());
assertEquals("TwoProfileName", result.getProfileName());
assertThat("Account serialization works 2",
asJson(result),
is(equalTo(jsonFixture("fixtures/transparent_account2.json"))));
verify(accountsManager, times(1)).get(eq("+14151232222"));
}
@Test
public void testAccountMissing() {
Response response = resources.getJerseyTest()
.target(String.format("/v1/transparency/account/%s", "3"))
.request()
.get();
assertEquals(404, response.getStatus());
verifyNoMoreInteractions(accountsManager);
}
}

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2018 Open WhisperSystems * Copyright (C) 2018 Open WhisperSystems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -17,21 +17,19 @@
package org.whispersystems.textsecuregcm.tests.storage; package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerListener; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerListener;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -42,8 +40,8 @@ import static org.mockito.Mockito.*;
public class AccountDatabaseCrawlerTest { public class AccountDatabaseCrawlerTest {
private static final String ACCOUNT1 = "+1"; private static final UUID ACCOUNT1 = UUID.randomUUID();
private static final String ACCOUNT2 = "+2"; private static final UUID ACCOUNT2 = UUID.randomUUID();
private static final int CHUNK_SIZE = 1000; private static final int CHUNK_SIZE = 1000;
private static final long CHUNK_INTERVAL_MS = 30_000L; private static final long CHUNK_INTERVAL_MS = 30_000L;
@ -59,8 +57,8 @@ public class AccountDatabaseCrawlerTest {
@Before @Before
public void setup() { public void setup() {
when(account1.getNumber()).thenReturn(ACCOUNT1); when(account1.getUuid()).thenReturn(ACCOUNT1);
when(account2.getNumber()).thenReturn(ACCOUNT2); when(account2.getUuid()).thenReturn(ACCOUNT2);
when(accounts.getAllFrom(anyInt())).thenReturn(Arrays.asList(account1, account2)); when(accounts.getAllFrom(anyInt())).thenReturn(Arrays.asList(account1, account2));
when(accounts.getAllFrom(eq(ACCOUNT1), anyInt())).thenReturn(Arrays.asList(account2)); when(accounts.getAllFrom(eq(ACCOUNT1), anyInt())).thenReturn(Arrays.asList(account2));
@ -72,20 +70,20 @@ public class AccountDatabaseCrawlerTest {
@Test @Test
public void testCrawlStart() throws AccountDatabaseCrawlerRestartException { public void testCrawlStart() throws AccountDatabaseCrawlerRestartException {
when(cache.getLastNumber()).thenReturn(Optional.empty()); when(cache.getLastUuid()).thenReturn(Optional.empty());
boolean accelerated = crawler.doPeriodicWork(); boolean accelerated = crawler.doPeriodicWork();
assertThat(accelerated).isFalse(); assertThat(accelerated).isFalse();
verify(cache, times(1)).claimActiveWork(any(String.class), anyLong()); verify(cache, times(1)).claimActiveWork(any(String.class), anyLong());
verify(cache, times(1)).getLastNumber(); verify(cache, times(1)).getLastUuid();
verify(listener, times(1)).onCrawlStart(); verify(listener, times(1)).onCrawlStart();
verify(accounts, times(1)).getAllFrom(eq(CHUNK_SIZE)); verify(accounts, times(1)).getAllFrom(eq(CHUNK_SIZE));
verify(accounts, times(0)).getAllFrom(any(String.class), eq(CHUNK_SIZE)); verify(accounts, times(0)).getAllFrom(any(UUID.class), eq(CHUNK_SIZE));
verify(account1, times(0)).getNumber(); verify(account1, times(0)).getUuid();
verify(account2, times(1)).getNumber(); verify(account2, times(1)).getUuid();
verify(listener, times(1)).onCrawlChunk(eq(Optional.empty()), eq(Arrays.asList(account1, account2))); verify(listener, times(1)).onCrawlChunk(eq(Optional.empty()), eq(Arrays.asList(account1, account2)));
verify(cache, times(1)).setLastNumber(eq(Optional.of(ACCOUNT2))); verify(cache, times(1)).setLastUuid(eq(Optional.of(ACCOUNT2)));
verify(cache, times(1)).isAccelerated(); verify(cache, times(1)).isAccelerated();
verify(cache, times(1)).releaseActiveWork(any(String.class)); verify(cache, times(1)).releaseActiveWork(any(String.class));
@ -98,18 +96,18 @@ public class AccountDatabaseCrawlerTest {
@Test @Test
public void testCrawlChunk() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunk() throws AccountDatabaseCrawlerRestartException {
when(cache.getLastNumber()).thenReturn(Optional.of(ACCOUNT1)); when(cache.getLastUuid()).thenReturn(Optional.of(ACCOUNT1));
boolean accelerated = crawler.doPeriodicWork(); boolean accelerated = crawler.doPeriodicWork();
assertThat(accelerated).isFalse(); assertThat(accelerated).isFalse();
verify(cache, times(1)).claimActiveWork(any(String.class), anyLong()); verify(cache, times(1)).claimActiveWork(any(String.class), anyLong());
verify(cache, times(1)).getLastNumber(); verify(cache, times(1)).getLastUuid();
verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE)); verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE));
verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE)); verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE));
verify(account2, times(1)).getNumber(); verify(account2, times(1)).getUuid();
verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2))); verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2)));
verify(cache, times(1)).setLastNumber(eq(Optional.of(ACCOUNT2))); verify(cache, times(1)).setLastUuid(eq(Optional.of(ACCOUNT2)));
verify(cache, times(1)).isAccelerated(); verify(cache, times(1)).isAccelerated();
verify(cache, times(1)).releaseActiveWork(any(String.class)); verify(cache, times(1)).releaseActiveWork(any(String.class));
@ -124,18 +122,18 @@ public class AccountDatabaseCrawlerTest {
@Test @Test
public void testCrawlChunkAccelerated() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkAccelerated() throws AccountDatabaseCrawlerRestartException {
when(cache.isAccelerated()).thenReturn(true); when(cache.isAccelerated()).thenReturn(true);
when(cache.getLastNumber()).thenReturn(Optional.of(ACCOUNT1)); when(cache.getLastUuid()).thenReturn(Optional.of(ACCOUNT1));
boolean accelerated = crawler.doPeriodicWork(); boolean accelerated = crawler.doPeriodicWork();
assertThat(accelerated).isTrue(); assertThat(accelerated).isTrue();
verify(cache, times(1)).claimActiveWork(any(String.class), anyLong()); verify(cache, times(1)).claimActiveWork(any(String.class), anyLong());
verify(cache, times(1)).getLastNumber(); verify(cache, times(1)).getLastUuid();
verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE)); verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE));
verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE)); verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE));
verify(account2, times(1)).getNumber(); verify(account2, times(1)).getUuid();
verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2))); verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2)));
verify(cache, times(1)).setLastNumber(eq(Optional.of(ACCOUNT2))); verify(cache, times(1)).setLastUuid(eq(Optional.of(ACCOUNT2)));
verify(cache, times(1)).isAccelerated(); verify(cache, times(1)).isAccelerated();
verify(cache, times(1)).releaseActiveWork(any(String.class)); verify(cache, times(1)).releaseActiveWork(any(String.class));
@ -149,19 +147,19 @@ public class AccountDatabaseCrawlerTest {
@Test @Test
public void testCrawlChunkRestart() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkRestart() throws AccountDatabaseCrawlerRestartException {
when(cache.getLastNumber()).thenReturn(Optional.of(ACCOUNT1)); when(cache.getLastUuid()).thenReturn(Optional.of(ACCOUNT1));
doThrow(AccountDatabaseCrawlerRestartException.class).when(listener).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2))); doThrow(AccountDatabaseCrawlerRestartException.class).when(listener).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2)));
boolean accelerated = crawler.doPeriodicWork(); boolean accelerated = crawler.doPeriodicWork();
assertThat(accelerated).isFalse(); assertThat(accelerated).isFalse();
verify(cache, times(1)).claimActiveWork(any(String.class), anyLong()); verify(cache, times(1)).claimActiveWork(any(String.class), anyLong());
verify(cache, times(1)).getLastNumber(); verify(cache, times(1)).getLastUuid();
verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE)); verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE));
verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE)); verify(accounts, times(1)).getAllFrom(eq(ACCOUNT1), eq(CHUNK_SIZE));
verify(account2, times(0)).getNumber(); verify(account2, times(0)).getNumber();
verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2))); verify(listener, times(1)).onCrawlChunk(eq(Optional.of(ACCOUNT1)), eq(Arrays.asList(account2)));
verify(cache, times(1)).setLastNumber(eq(Optional.empty())); verify(cache, times(1)).setLastUuid(eq(Optional.empty()));
verify(cache, times(1)).clearAccelerate(); verify(cache, times(1)).clearAccelerate();
verify(cache, times(1)).isAccelerated(); verify(cache, times(1)).isAccelerated();
verify(cache, times(1)).releaseActiveWork(any(String.class)); verify(cache, times(1)).releaseActiveWork(any(String.class));
@ -176,19 +174,19 @@ public class AccountDatabaseCrawlerTest {
@Test @Test
public void testCrawlEnd() { public void testCrawlEnd() {
when(cache.getLastNumber()).thenReturn(Optional.of(ACCOUNT2)); when(cache.getLastUuid()).thenReturn(Optional.of(ACCOUNT2));
boolean accelerated = crawler.doPeriodicWork(); boolean accelerated = crawler.doPeriodicWork();
assertThat(accelerated).isFalse(); assertThat(accelerated).isFalse();
verify(cache, times(1)).claimActiveWork(any(String.class), anyLong()); verify(cache, times(1)).claimActiveWork(any(String.class), anyLong());
verify(cache, times(1)).getLastNumber(); verify(cache, times(1)).getLastUuid();
verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE)); verify(accounts, times(0)).getAllFrom(eq(CHUNK_SIZE));
verify(accounts, times(1)).getAllFrom(eq(ACCOUNT2), eq(CHUNK_SIZE)); verify(accounts, times(1)).getAllFrom(eq(ACCOUNT2), eq(CHUNK_SIZE));
verify(account1, times(0)).getNumber(); verify(account1, times(0)).getNumber();
verify(account2, times(0)).getNumber(); verify(account2, times(0)).getNumber();
verify(listener, times(1)).onCrawlEnd(eq(Optional.of(ACCOUNT2))); verify(listener, times(1)).onCrawlEnd(eq(Optional.of(ACCOUNT2)));
verify(cache, times(1)).setLastNumber(eq(Optional.empty())); verify(cache, times(1)).setLastUuid(eq(Optional.empty()));
verify(cache, times(1)).clearAccelerate(); verify(cache, times(1)).clearAccelerate();
verify(cache, times(1)).isAccelerated(); verify(cache, times(1)).isAccelerated();
verify(cache, times(1)).releaseActiveWork(any(String.class)); verify(cache, times(1)).releaseActiveWork(any(String.class));

View File

@ -7,6 +7,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import java.util.HashSet; import java.util.HashSet;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -47,21 +48,21 @@ public class AccountTest {
@Test @Test
public void testAccountActive() { public void testAccountActive() {
Account recentAccount = new Account("+14152222222", new HashSet<Device>() {{ Account recentAccount = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(recentMasterDevice); add(recentMasterDevice);
add(recentSecondaryDevice); add(recentSecondaryDevice);
}}, "1234".getBytes()); }}, "1234".getBytes());
assertTrue(recentAccount.isEnabled()); assertTrue(recentAccount.isEnabled());
Account oldSecondaryAccount = new Account("+14152222222", new HashSet<Device>() {{ Account oldSecondaryAccount = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(recentMasterDevice); add(recentMasterDevice);
add(agingSecondaryDevice); add(agingSecondaryDevice);
}}, "1234".getBytes()); }}, "1234".getBytes());
assertTrue(oldSecondaryAccount.isEnabled()); assertTrue(oldSecondaryAccount.isEnabled());
Account agingPrimaryAccount = new Account("+14152222222", new HashSet<Device>() {{ Account agingPrimaryAccount = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(oldMasterDevice); add(oldMasterDevice);
add(agingSecondaryDevice); add(agingSecondaryDevice);
}}, "1234".getBytes()); }}, "1234".getBytes());
@ -71,7 +72,7 @@ public class AccountTest {
@Test @Test
public void testAccountInactive() { public void testAccountInactive() {
Account oldPrimaryAccount = new Account("+14152222222", new HashSet<Device>() {{ Account oldPrimaryAccount = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(oldMasterDevice); add(oldMasterDevice);
add(oldSecondaryDevice); add(oldSecondaryDevice);
}}, "1234".getBytes()); }}, "1234".getBytes());

View File

@ -9,6 +9,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import java.util.HashSet; import java.util.HashSet;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static junit.framework.TestCase.assertSame; import static junit.framework.TestCase.assertSame;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
@ -21,14 +22,17 @@ import redis.clients.jedis.exceptions.JedisException;
public class AccountsManagerTest { public class AccountsManagerTest {
@Test @Test
public void testGetAccountInCache() { public void testGetAccountByNumberInCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class); ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class ); Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class ); Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class ); DirectoryManager directoryManager = mock(DirectoryManager.class );
UUID uuid = UUID.randomUUID();
when(cacheClient.getReadResource()).thenReturn(jedis); when(cacheClient.getReadResource()).thenReturn(jedis);
when(jedis.get(eq("Account5+14152222222"))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}"); when(jedis.get(eq("AccountMap::+14152222222"))).thenReturn(uuid.toString());
when(jedis.get(eq("Account::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}");
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient); AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> account = accountsManager.get("+14152222222"); Optional<Account> account = accountsManager.get("+14152222222");
@ -37,23 +41,52 @@ public class AccountsManagerTest {
assertEquals(account.get().getNumber(), "+14152222222"); assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(account.get().getProfileName(), "test"); assertEquals(account.get().getProfileName(), "test");
verify(jedis, times(1)).get(eq("Account5+14152222222")); verify(jedis, times(1)).get(eq("AccountMap::+14152222222"));
verify(jedis, times(1)).get(eq("Account::" + uuid.toString()));
verify(jedis, times(2)).close();
verifyNoMoreInteractions(jedis);
verifyNoMoreInteractions(accounts);
}
@Test
public void testGetAccountByUuidInCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class );
UUID uuid = UUID.randomUUID();
when(cacheClient.getReadResource()).thenReturn(jedis);
when(jedis.get(eq("Account::" + uuid.toString()))).thenReturn("{\"number\": \"+14152222222\", \"name\": \"test\"}");
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> account = accountsManager.get(uuid);
assertTrue(account.isPresent());
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(account.get().getUuid(), uuid);
assertEquals(account.get().getProfileName(), "test");
verify(jedis, times(1)).get(eq("Account::" + uuid.toString()));
verify(jedis, times(1)).close(); verify(jedis, times(1)).close();
verifyNoMoreInteractions(jedis); verifyNoMoreInteractions(jedis);
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
} }
@Test @Test
public void testGetAccountNotInCache() { public void testGetAccountByNumberNotInCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class); ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class ); Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class ); Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class ); DirectoryManager directoryManager = mock(DirectoryManager.class );
Account account = new Account("+14152222222", new HashSet<>(), new byte[16]); UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(cacheClient.getReadResource()).thenReturn(jedis); when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis); when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("Account5+14152222222"))).thenReturn(null); when(jedis.get(eq("AccountMap::+14152222222"))).thenReturn(null);
when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient); AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
@ -62,8 +95,9 @@ public class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(jedis, times(1)).get(eq("Account5+14152222222")); verify(jedis, times(1)).get(eq("AccountMap::+14152222222"));
verify(jedis, times(1)).set(eq("Account5+14152222222"), anyString()); verify(jedis, times(1)).set(eq("AccountMap::+14152222222"), eq(uuid.toString()));
verify(jedis, times(1)).set(eq("Account::" + uuid.toString()), anyString());
verify(jedis, times(2)).close(); verify(jedis, times(2)).close();
verifyNoMoreInteractions(jedis); verifyNoMoreInteractions(jedis);
@ -72,16 +106,47 @@ public class AccountsManagerTest {
} }
@Test @Test
public void testGetAccountBrokenCache() { public void testGetAccountByUuidNotInCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class); ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class ); Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class ); Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class ); DirectoryManager directoryManager = mock(DirectoryManager.class );
Account account = new Account("+14152222222", new HashSet<>(), new byte[16]); UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(cacheClient.getReadResource()).thenReturn(jedis); when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis); when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("Account5+14152222222"))).thenThrow(new JedisException("Connection lost!")); when(jedis.get(eq("Account::" + uuid))).thenReturn(null);
when(accounts.get(eq(uuid))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> retrieved = accountsManager.get(uuid);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(jedis, times(1)).get(eq("Account::" + uuid));
verify(jedis, times(1)).set(eq("AccountMap::+14152222222"), eq(uuid.toString()));
verify(jedis, times(1)).set(eq("Account::" + uuid.toString()), anyString());
verify(jedis, times(2)).close();
verifyNoMoreInteractions(jedis);
verify(accounts, times(1)).get(eq(uuid));
verifyNoMoreInteractions(accounts);
}
@Test
public void testGetAccountByNumberBrokenCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class );
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("AccountMap::+14152222222"))).thenThrow(new JedisException("Connection lost!"));
when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account)); when(accounts.get(eq("+14152222222"))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient); AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
@ -90,8 +155,9 @@ public class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(jedis, times(1)).get(eq("Account5+14152222222")); verify(jedis, times(1)).get(eq("AccountMap::+14152222222"));
verify(jedis, times(1)).set(eq("Account5+14152222222"), anyString()); verify(jedis, times(1)).set(eq("AccountMap::+14152222222"), eq(uuid.toString()));
verify(jedis, times(1)).set(eq("Account::" + uuid.toString()), anyString());
verify(jedis, times(2)).close(); verify(jedis, times(2)).close();
verifyNoMoreInteractions(jedis); verifyNoMoreInteractions(jedis);
@ -99,6 +165,35 @@ public class AccountsManagerTest {
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
} }
@Test
public void testGetAccountByUuidBrokenCache() {
ReplicatedJedisPool cacheClient = mock(ReplicatedJedisPool.class);
Jedis jedis = mock(Jedis.class );
Accounts accounts = mock(Accounts.class );
DirectoryManager directoryManager = mock(DirectoryManager.class );
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(cacheClient.getReadResource()).thenReturn(jedis);
when(cacheClient.getWriteResource()).thenReturn(jedis);
when(jedis.get(eq("Account::" + uuid))).thenThrow(new JedisException("Connection lost!"));
when(accounts.get(eq(uuid))).thenReturn(Optional.of(account));
AccountsManager accountsManager = new AccountsManager(accounts, directoryManager, cacheClient);
Optional<Account> retrieved = accountsManager.get(uuid);
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(jedis, times(1)).get(eq("Account::" + uuid));
verify(jedis, times(1)).set(eq("AccountMap::+14152222222"), eq(uuid.toString()));
verify(jedis, times(1)).set(eq("Account::" + uuid.toString()), anyString());
verify(jedis, times(2)).close();
verifyNoMoreInteractions(jedis);
verify(accounts, times(1)).get(eq(uuid));
verifyNoMoreInteractions(accounts);
}
} }

View File

@ -1,5 +1,6 @@
package org.whispersystems.textsecuregcm.tests.storage; package org.whispersystems.textsecuregcm.tests.storage;
import com.fasterxml.uuid.UUIDComparator;
import com.opentable.db.postgres.embedded.LiquibasePreparer; import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit.EmbeddedPostgresRules; import com.opentable.db.postgres.junit.EmbeddedPostgresRules;
import com.opentable.db.postgres.junit.PreparedDbRule; import com.opentable.db.postgres.junit.PreparedDbRule;
@ -17,6 +18,8 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper; import org.whispersystems.textsecuregcm.storage.mappers.AccountRowMapper;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException; import java.io.IOException;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
@ -25,11 +28,13 @@ import java.sql.SQLException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import io.github.resilience4j.circuitbreaker.CircuitBreakerOpenException; import io.github.resilience4j.circuitbreaker.CircuitBreakerOpenException;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@ -56,12 +61,12 @@ public class AccountsTest {
@Test @Test
public void testStore() throws SQLException, IOException { public void testStore() throws SQLException, IOException {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
accounts.create(account); accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account); verifyStoredState(statement, "+14151112222", account.getUuid(), account);
} }
@Test @Test
@ -70,12 +75,12 @@ public class AccountsTest {
devices.add(generateDevice(1)); devices.add(generateDevice(1));
devices.add(generateDevice(2)); devices.add(generateDevice(2));
Account account = generateAccount("+14151112222", devices); Account account = generateAccount("+14151112222", UUID.randomUUID(), devices);
accounts.create(account); accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account); verifyStoredState(statement, "+14151112222", account.getUuid(), account);
} }
@Test @Test
@ -84,13 +89,15 @@ public class AccountsTest {
devicesFirst.add(generateDevice(1)); devicesFirst.add(generateDevice(1));
devicesFirst.add(generateDevice(2)); devicesFirst.add(generateDevice(2));
Account accountFirst = generateAccount("+14151112222", devicesFirst); UUID uuidFirst = UUID.randomUUID();
Account accountFirst = generateAccount("+14151112222", uuidFirst, devicesFirst);
Set<Device> devicesSecond = new HashSet<>(); Set<Device> devicesSecond = new HashSet<>();
devicesSecond.add(generateDevice(1)); devicesSecond.add(generateDevice(1));
devicesSecond.add(generateDevice(2)); devicesSecond.add(generateDevice(2));
Account accountSecond = generateAccount("+14152221111", devicesSecond); UUID uuidSecond = UUID.randomUUID();
Account accountSecond = generateAccount("+14152221111", uuidSecond, devicesSecond);
accounts.create(accountFirst); accounts.create(accountFirst);
accounts.create(accountSecond); accounts.create(accountSecond);
@ -101,31 +108,43 @@ public class AccountsTest {
assertThat(retrievedFirst.isPresent()).isTrue(); assertThat(retrievedFirst.isPresent()).isTrue();
assertThat(retrievedSecond.isPresent()).isTrue(); assertThat(retrievedSecond.isPresent()).isTrue();
verifyStoredState("+14151112222", retrievedFirst.get(), accountFirst); verifyStoredState("+14151112222", uuidFirst, retrievedFirst.get(), accountFirst);
verifyStoredState("+14152221111", retrievedSecond.get(), accountSecond); verifyStoredState("+14152221111", uuidSecond, retrievedSecond.get(), accountSecond);
retrievedFirst = accounts.get(uuidFirst);
retrievedSecond = accounts.get(uuidSecond);
assertThat(retrievedFirst.isPresent()).isTrue();
assertThat(retrievedSecond.isPresent()).isTrue();
verifyStoredState("+14151112222", uuidFirst, retrievedFirst.get(), accountFirst);
verifyStoredState("+14152221111", uuidSecond, retrievedSecond.get(), accountSecond);
} }
@Test @Test
public void testOverwrite() throws Exception { public void testOverwrite() throws Exception {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device)); UUID firstUuid = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, Collections.singleton(device));
accounts.create(account); accounts.create(account);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?"); PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM accounts WHERE number = ?");
verifyStoredState(statement, "+14151112222", account); verifyStoredState(statement, "+14151112222", account.getUuid(), account);
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1); device = generateDevice(1);
account = generateAccount("+14151112222", Collections.singleton(device)); account = generateAccount("+14151112222", secondUuid, Collections.singleton(device));
accounts.create(account); accounts.create(account);
verifyStoredState(statement, "+14151112222", account); verifyStoredState(statement, "+14151112222", firstUuid, account);
} }
@Test @Test
public void testUpdate() { public void testUpdate() {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
accounts.create(account); accounts.create(account);
@ -136,7 +155,12 @@ public class AccountsTest {
Optional<Account> retrieved = accounts.get("+14151112222"); Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", retrieved.get(), account); verifyStoredState("+14151112222", account.getUuid(), retrieved.get(), account);
retrieved = accounts.get(account.getUuid());
assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", account.getUuid(), retrieved.get(), account);
} }
@Test @Test
@ -144,24 +168,26 @@ public class AccountsTest {
List<Account> users = new ArrayList<>(); List<Account> users = new ArrayList<>();
for (int i=1;i<=100;i++) { for (int i=1;i<=100;i++) {
Account account = generateAccount("+1" + String.format("%03d", i)); Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID());
users.add(account); users.add(account);
accounts.create(account); accounts.create(account);
} }
users.sort((account, t1) -> UUIDComparator.staticCompare(account.getUuid(), t1.getUuid()));
List<Account> retrieved = accounts.getAllFrom(10); List<Account> retrieved = accounts.getAllFrom(10);
assertThat(retrieved.size()).isEqualTo(10); assertThat(retrieved.size()).isEqualTo(10);
for (int i=0;i<retrieved.size();i++) { for (int i=0;i<retrieved.size();i++) {
verifyStoredState("+1" + String.format("%03d", (i + 1)), retrieved.get(i), users.get(i)); verifyStoredState(users.get(i).getNumber(), users.get(i).getUuid(), retrieved.get(i), users.get(i));
} }
for (int j=0;j<9;j++) { for (int j=0;j<9;j++) {
retrieved = accounts.getAllFrom(retrieved.get(9).getNumber(), 10); retrieved = accounts.getAllFrom(retrieved.get(9).getUuid(), 10);
assertThat(retrieved.size()).isEqualTo(10); assertThat(retrieved.size()).isEqualTo(10);
for (int i=0;i<retrieved.size();i++) { for (int i=0;i<retrieved.size();i++) {
verifyStoredState("+1" + String.format("%03d", (10 + (j * 10) + i + 1)), retrieved.get(i), users.get(10 + (j * 10) + i)); verifyStoredState(users.get(10 + (j * 10) + i).getNumber(), users.get(10 + (j * 10) + i).getUuid(), retrieved.get(i), users.get(10 + (j * 10) + i));
} }
} }
} }
@ -169,7 +195,7 @@ public class AccountsTest {
@Test @Test
public void testVacuum() { public void testVacuum() {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
accounts.create(account); accounts.create(account);
accounts.vacuum(); accounts.vacuum();
@ -177,18 +203,21 @@ public class AccountsTest {
Optional<Account> retrieved = accounts.get("+14151112222"); Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", retrieved.get(), account); verifyStoredState("+14151112222", account.getUuid(), retrieved.get(), account);
} }
@Test @Test
public void testMissing() { public void testMissing() {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
accounts.create(account); accounts.create(account);
Optional<Account> retrieved = accounts.get("+11111111"); Optional<Account> retrieved = accounts.get("+11111111");
assertThat(retrieved.isPresent()).isFalse(); assertThat(retrieved.isPresent()).isFalse();
retrieved = accounts.get(UUID.randomUUID());
assertThat(retrieved.isPresent()).isFalse();
} }
@Test @Test
@ -203,7 +232,7 @@ public class AccountsTest {
configuration.setFailureRateThreshold(50); configuration.setFailureRateThreshold(50);
Accounts accounts = new Accounts(new FaultTolerantDatabase("testAccountBreaker", jdbi, configuration)); Accounts accounts = new Accounts(new FaultTolerantDatabase("testAccountBreaker", jdbi, configuration));
Account account = generateAccount("+14151112222"); Account account = generateAccount("+14151112222", UUID.randomUUID());
try { try {
accounts.update(account); accounts.update(account);
@ -244,20 +273,20 @@ public class AccountsTest {
return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(), null, "testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt(), random.nextBoolean(), 0); return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(), null, "testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt(), random.nextBoolean(), 0);
} }
private Account generateAccount(String number) { private Account generateAccount(String number, UUID uuid) {
Device device = generateDevice(1); Device device = generateDevice(1);
return generateAccount(number, Collections.singleton(device)); return generateAccount(number, uuid, Collections.singleton(device));
} }
private Account generateAccount(String number, Set<Device> devices) { private Account generateAccount(String number, UUID uuid, Set<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16]; byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis()); Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255)); Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
return new Account(number, devices, unidentifiedAccessKey); return new Account(number, uuid, devices, unidentifiedAccessKey);
} }
private void verifyStoredState(PreparedStatement statement, String number, Account expecting) private void verifyStoredState(PreparedStatement statement, String number, UUID uuid, Account expecting)
throws SQLException, IOException throws SQLException, IOException
{ {
statement.setString(1, number); statement.setString(1, number);
@ -269,7 +298,7 @@ public class AccountsTest {
assertThat(data).isNotEmpty(); assertThat(data).isNotEmpty();
Account result = new AccountRowMapper().map(resultSet, null); Account result = new AccountRowMapper().map(resultSet, null);
verifyStoredState(number, result, expecting); verifyStoredState(number, uuid, result, expecting);
} else { } else {
throw new AssertionError("No data"); throw new AssertionError("No data");
} }
@ -277,9 +306,10 @@ public class AccountsTest {
assertThat(resultSet.next()).isFalse(); assertThat(resultSet.next()).isFalse();
} }
private void verifyStoredState(String number, Account result, Account expecting) { private void verifyStoredState(String number, UUID uuid, Account result, Account expecting) {
assertThat(result.getNumber()).isEqualTo(number); assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen()); assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue(); assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) { for (Device expectingDevice : expecting.getDevices()) {

View File

@ -31,6 +31,7 @@ import org.junit.Test;
import redis.clients.jedis.Jedis; import redis.clients.jedis.Jedis;
import java.util.Arrays; import java.util.Arrays;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.Optional; import java.util.Optional;
@ -46,9 +47,13 @@ import static org.mockito.Mockito.when;
public class ActiveUserCounterTest { public class ActiveUserCounterTest {
private final String NUMBER_IOS = "+15551234567"; private final UUID UUID_IOS = UUID.randomUUID();
private final String NUMBER_ANDROID = "+5511987654321"; private final UUID UUID_ANDROID = UUID.randomUUID();
private final String NUMBER_NODEVICE = "+5215551234567"; private final UUID UUID_NODEVICE = UUID.randomUUID();
private final String ACCOUNT_NUMBER_IOS = "+15551234567";
private final String ACCOUNT_NUMBER_ANDROID = "+5511987654321";
private final String ACCOUNT_NUMBER_NODEVICE = "+5215551234567";
private final String TALLY_KEY = "active_user_tally"; private final String TALLY_KEY = "active_user_tally";
@ -79,14 +84,17 @@ public class ActiveUserCounterTest {
when(iosDevice.getGcmId()).thenReturn(null); when(iosDevice.getGcmId()).thenReturn(null);
when(iosDevice.getLastSeen()).thenReturn(halfDayAgo); when(iosDevice.getLastSeen()).thenReturn(halfDayAgo);
when(iosAccount.getNumber()).thenReturn(NUMBER_IOS); when(iosAccount.getUuid()).thenReturn(UUID_IOS);
when(iosAccount.getMasterDevice()).thenReturn(Optional.of(iosDevice)); when(iosAccount.getMasterDevice()).thenReturn(Optional.of(iosDevice));
when(iosAccount.getNumber()).thenReturn(ACCOUNT_NUMBER_IOS);
when(androidAccount.getNumber()).thenReturn(NUMBER_ANDROID); when(androidAccount.getUuid()).thenReturn(UUID_ANDROID);
when(androidAccount.getMasterDevice()).thenReturn(Optional.of(androidDevice)); when(androidAccount.getMasterDevice()).thenReturn(Optional.of(androidDevice));
when(androidAccount.getNumber()).thenReturn(ACCOUNT_NUMBER_ANDROID);
when(noDeviceAccount.getNumber()).thenReturn(NUMBER_NODEVICE); when(noDeviceAccount.getUuid()).thenReturn(UUID_NODEVICE);
when(noDeviceAccount.getMasterDevice()).thenReturn(Optional.ofNullable(null)); when(noDeviceAccount.getMasterDevice()).thenReturn(Optional.ofNullable(null));
when(noDeviceAccount.getNumber()).thenReturn(ACCOUNT_NUMBER_NODEVICE);
when(jedis.get(any(String.class))).thenReturn("{\"fromNumber\":\"+\",\"platforms\":{},\"countries\":{}}"); when(jedis.get(any(String.class))).thenReturn("{\"fromNumber\":\"+\",\"platforms\":{},\"countries\":{}}");
when(jedisPool.getWriteResource()).thenReturn(jedis); when(jedisPool.getWriteResource()).thenReturn(jedis);
@ -137,7 +145,7 @@ public class ActiveUserCounterTest {
@Test @Test
public void testCrawlChunkValidAccount() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkValidAccount() throws AccountDatabaseCrawlerRestartException {
activeUserCounter.onCrawlChunk(Optional.of(NUMBER_IOS), Arrays.asList(iosAccount)); activeUserCounter.onCrawlChunk(Optional.of(UUID_IOS), Arrays.asList(iosAccount));
verify(iosAccount, times(1)).getMasterDevice(); verify(iosAccount, times(1)).getMasterDevice();
verify(iosAccount, times(1)).getNumber(); verify(iosAccount, times(1)).getNumber();
@ -148,7 +156,7 @@ public class ActiveUserCounterTest {
verify(jedisPool, times(1)).getWriteResource(); verify(jedisPool, times(1)).getWriteResource();
verify(jedis, times(1)).get(any(String.class)); verify(jedis, times(1)).get(any(String.class));
verify(jedis, times(1)).set(any(String.class), eq("{\"fromNumber\":\""+NUMBER_IOS+"\",\"platforms\":{\"ios\":[1,1,1,1,1]},\"countries\":{\"1\":[1,1,1,1,1]}}")); verify(jedis, times(1)).set(any(String.class), eq("{\"fromUuid\":\""+UUID_IOS.toString()+"\",\"platforms\":{\"ios\":[1,1,1,1,1]},\"countries\":{\"1\":[1,1,1,1,1]}}"));
verify(jedis, times(1)).close(); verify(jedis, times(1)).close();
verify(metricsFactory, times(0)).getReporters(); verify(metricsFactory, times(0)).getReporters();
@ -166,13 +174,13 @@ public class ActiveUserCounterTest {
@Test @Test
public void testCrawlChunkNoDeviceAccount() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkNoDeviceAccount() throws AccountDatabaseCrawlerRestartException {
activeUserCounter.onCrawlChunk(Optional.of(NUMBER_NODEVICE), Arrays.asList(noDeviceAccount)); activeUserCounter.onCrawlChunk(Optional.of(UUID_NODEVICE), Arrays.asList(noDeviceAccount));
verify(noDeviceAccount, times(1)).getMasterDevice(); verify(noDeviceAccount, times(1)).getMasterDevice();
verify(jedisPool, times(1)).getWriteResource(); verify(jedisPool, times(1)).getWriteResource();
verify(jedis, times(1)).get(eq(TALLY_KEY)); verify(jedis, times(1)).get(eq(TALLY_KEY));
verify(jedis, times(1)).set(any(String.class), eq("{\"fromNumber\":\""+NUMBER_NODEVICE+"\",\"platforms\":{},\"countries\":{}}")); verify(jedis, times(1)).set(any(String.class), eq("{\"fromUuid\":\""+UUID_NODEVICE+"\",\"platforms\":{},\"countries\":{}}"));
verify(jedis, times(1)).close(); verify(jedis, times(1)).close();
verify(metricsFactory, times(0)).getReporters(); verify(metricsFactory, times(0)).getReporters();
@ -190,7 +198,7 @@ public class ActiveUserCounterTest {
@Test @Test
public void testCrawlChunkMixedAccount() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkMixedAccount() throws AccountDatabaseCrawlerRestartException {
activeUserCounter.onCrawlChunk(Optional.of(NUMBER_IOS), Arrays.asList(iosAccount, androidAccount, noDeviceAccount)); activeUserCounter.onCrawlChunk(Optional.of(UUID_IOS), Arrays.asList(iosAccount, androidAccount, noDeviceAccount));
verify(iosAccount, times(1)).getMasterDevice(); verify(iosAccount, times(1)).getMasterDevice();
verify(iosAccount, times(1)).getNumber(); verify(iosAccount, times(1)).getNumber();
@ -208,7 +216,7 @@ public class ActiveUserCounterTest {
verify(jedisPool, times(1)).getWriteResource(); verify(jedisPool, times(1)).getWriteResource();
verify(jedis, times(1)).get(eq(TALLY_KEY)); verify(jedis, times(1)).get(eq(TALLY_KEY));
verify(jedis, times(1)).set(any(String.class), eq("{\"fromNumber\":\""+NUMBER_IOS+"\",\"platforms\":{\"android\":[0,0,0,1,1],\"ios\":[1,1,1,1,1]},\"countries\":{\"55\":[0,0,0,1,1],\"1\":[1,1,1,1,1]}}")); verify(jedis, times(1)).set(any(String.class), eq("{\"fromUuid\":\""+UUID_IOS+"\",\"platforms\":{\"android\":[0,0,0,1,1],\"ios\":[1,1,1,1,1]},\"countries\":{\"55\":[0,0,0,1,1],\"1\":[1,1,1,1,1]}}"));
verify(jedis, times(1)).close(); verify(jedis, times(1)).close();
verify(metricsFactory, times(0)).getReporters(); verify(metricsFactory, times(0)).getReporters();

View File

@ -17,23 +17,23 @@
package org.whispersystems.textsecuregcm.tests.storage; package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest; import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationResponse; import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationResponse;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle;
import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle;
import org.whispersystems.textsecuregcm.storage.DirectoryReconciler; import org.whispersystems.textsecuregcm.storage.DirectoryReconciler;
import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient; import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import java.util.Arrays; import java.util.Arrays;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -41,8 +41,10 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
public class DirectoryReconcilerTest { public class DirectoryReconcilerTest {
private static final String VALID_NUMBER = "valid"; private static final UUID VALID_UUID = UUID.randomUUID();
private static final String INACTIVE_NUMBER = "inactive"; private static final String VALID_NUMBERRR = "+14152222222";
private static final UUID INACTIVE_UUID = UUID.randomUUID();
private static final String INACTIVE_NUMBERRR = "+14151111111";
private final Account activeAccount = mock(Account.class); private final Account activeAccount = mock(Account.class);
private final Account inactiveAccount = mock(Account.class); private final Account inactiveAccount = mock(Account.class);
@ -56,9 +58,11 @@ public class DirectoryReconcilerTest {
@Before @Before
public void setup() { public void setup() {
when(activeAccount.getNumber()).thenReturn(VALID_NUMBER); when(activeAccount.getUuid()).thenReturn(VALID_UUID);
when(activeAccount.isEnabled()).thenReturn(true); when(activeAccount.isEnabled()).thenReturn(true);
when(inactiveAccount.getNumber()).thenReturn(INACTIVE_NUMBER); when(activeAccount.getNumber()).thenReturn(VALID_NUMBERRR);
when(inactiveAccount.getUuid()).thenReturn(INACTIVE_UUID);
when(inactiveAccount.getNumber()).thenReturn(INACTIVE_NUMBERRR);
when(inactiveAccount.isEnabled()).thenReturn(false); when(inactiveAccount.isEnabled()).thenReturn(false);
when(directoryManager.startBatchOperation()).thenReturn(batchOperationHandle); when(directoryManager.startBatchOperation()).thenReturn(batchOperationHandle);
} }
@ -66,27 +70,28 @@ public class DirectoryReconcilerTest {
@Test @Test
public void testCrawlChunkValid() throws AccountDatabaseCrawlerRestartException { public void testCrawlChunkValid() throws AccountDatabaseCrawlerRestartException {
when(reconciliationClient.sendChunk(any())).thenReturn(successResponse); when(reconciliationClient.sendChunk(any())).thenReturn(successResponse);
directoryReconciler.onCrawlChunk(Optional.of(VALID_NUMBER), Arrays.asList(activeAccount, inactiveAccount)); directoryReconciler.onCrawlChunk(Optional.of(VALID_UUID), Arrays.asList(activeAccount, inactiveAccount));
verify(activeAccount, times(2)).getNumber(); verify(activeAccount, times(2)).getNumber();
verify(activeAccount, times(2)).isEnabled(); verify(activeAccount, times(2)).isEnabled();
verify(inactiveAccount, times(2)).getNumber(); verify(inactiveAccount, times(1)).getUuid();
verify(inactiveAccount, times(1)).getNumber();
verify(inactiveAccount, times(2)).isEnabled(); verify(inactiveAccount, times(2)).isEnabled();
ArgumentCaptor<DirectoryReconciliationRequest> request = ArgumentCaptor.forClass(DirectoryReconciliationRequest.class); ArgumentCaptor<DirectoryReconciliationRequest> request = ArgumentCaptor.forClass(DirectoryReconciliationRequest.class);
verify(reconciliationClient, times(1)).sendChunk(request.capture()); verify(reconciliationClient, times(1)).sendChunk(request.capture());
assertThat(request.getValue().getFromNumber()).isEqualTo(VALID_NUMBER); assertThat(request.getValue().getFromUuid()).isEqualTo(VALID_UUID);
assertThat(request.getValue().getToNumber()).isEqualTo(INACTIVE_NUMBER); assertThat(request.getValue().getToUuid()).isEqualTo(INACTIVE_UUID);
assertThat(request.getValue().getNumbers()).isEqualTo(Arrays.asList(VALID_NUMBER)); assertThat(request.getValue().getNumbers()).isEqualTo(Arrays.asList(VALID_NUMBERRR));
ArgumentCaptor<ClientContact> addedContact = ArgumentCaptor.forClass(ClientContact.class); ArgumentCaptor<ClientContact> addedContact = ArgumentCaptor.forClass(ClientContact.class);
verify(directoryManager, times(1)).startBatchOperation(); verify(directoryManager, times(1)).startBatchOperation();
verify(directoryManager, times(1)).add(eq(batchOperationHandle), addedContact.capture()); verify(directoryManager, times(1)).add(eq(batchOperationHandle), addedContact.capture());
verify(directoryManager, times(1)).remove(eq(batchOperationHandle), eq(INACTIVE_NUMBER)); verify(directoryManager, times(1)).remove(eq(batchOperationHandle), eq(INACTIVE_NUMBERRR));
verify(directoryManager, times(1)).stopBatchOperation(eq(batchOperationHandle)); verify(directoryManager, times(1)).stopBatchOperation(eq(batchOperationHandle));
assertThat(addedContact.getValue().getToken()).isEqualTo(Util.getContactToken(VALID_NUMBER)); assertThat(addedContact.getValue().getToken()).isEqualTo(Util.getContactToken(VALID_NUMBERRR));
verifyNoMoreInteractions(activeAccount); verifyNoMoreInteractions(activeAccount);
verifyNoMoreInteractions(inactiveAccount); verifyNoMoreInteractions(inactiveAccount);

View File

@ -1,38 +0,0 @@
package org.whispersystems.textsecuregcm.tests.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PublicAccount;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import java.util.Collections;
import java.util.Set;
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertNull;
public class PublicAccountTest {
@Test
public void testPinSanitation() throws IOException {
Set<Device> devices = Collections.singleton(new Device(1, "foo", "bar", "12345", null, "gcm-1234", null, null, true, 1234, new SignedPreKey(1, "public-foo", "signature-foo"), 31337, 31336, "Android4Life", true, 0));
Account account = new Account("+14151231234", devices, new byte[16]);
account.setPin("123456");
PublicAccount publicAccount = new PublicAccount(account);
String serialized = SystemMapper.getMapper().writeValueAsString(publicAccount);
JsonNode result = SystemMapper.getMapper().readTree(serialized);
assertEquals("******", result.get("pin").textValue());
assertNull(result.get("number"));
}
}

View File

@ -13,6 +13,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
@ -62,7 +63,7 @@ public class PushFeedbackProcessorTest {
@Test @Test
public void testEmpty() { public void testEmpty() {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue); PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.onCrawlChunk(Optional.of("+14152222222"), Collections.emptyList()); processor.onCrawlChunk(Optional.of(UUID.randomUUID()), Collections.emptyList());
verifyZeroInteractions(accountsManager); verifyZeroInteractions(accountsManager);
verifyZeroInteractions(directoryQueue); verifyZeroInteractions(directoryQueue);
@ -71,7 +72,7 @@ public class PushFeedbackProcessorTest {
@Test @Test
public void testUpdate() { public void testUpdate() {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue); PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.onCrawlChunk(Optional.of("+14153333333"), List.of(uninstalledAccount, mixedAccount, stillActiveAccount, freshAccount, cleanAccount)); processor.onCrawlChunk(Optional.of(UUID.randomUUID()), List.of(uninstalledAccount, mixedAccount, stillActiveAccount, freshAccount, cleanAccount));
verify(uninstalledDevice).setApnId(isNull()); verify(uninstalledDevice).setApnId(isNull());
verify(uninstalledDevice).setGcmId(isNull()); verify(uninstalledDevice).setGcmId(isNull());

View File

@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.tests.util; package org.whispersystems.textsecuregcm.tests.util;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
@ -11,26 +13,32 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import io.dropwizard.auth.AuthFilter; import io.dropwizard.auth.AuthFilter;
import io.dropwizard.auth.PolymorphicAuthDynamicFeature; import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter; import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class AuthHelper { public class AuthHelper {
public static final String VALID_NUMBER = "+14150000000"; public static final String VALID_NUMBER = "+14150000000";
public static final UUID VALID_UUID = UUID.randomUUID();
public static final String VALID_PASSWORD = "foo"; public static final String VALID_PASSWORD = "foo";
public static final String VALID_NUMBER_TWO = "+201511111110"; public static final String VALID_NUMBER_TWO = "+201511111110";
public static final UUID VALID_UUID_TWO = UUID.randomUUID();
public static final String VALID_PASSWORD_TWO = "baz"; public static final String VALID_PASSWORD_TWO = "baz";
public static final String INVVALID_NUMBER = "+14151111111"; public static final String INVVALID_NUMBER = "+14151111111";
public static final UUID INVALID_UUID = UUID.randomUUID();
public static final String INVALID_PASSWORD = "bar"; public static final String INVALID_PASSWORD = "bar";
public static final String DISABLED_NUMBER = "+78888888"; public static final String DISABLED_NUMBER = "+78888888";
public static final UUID DISABLED_UUID = UUID.randomUUID();
public static final String DISABLED_PASSWORD = "poof"; public static final String DISABLED_PASSWORD = "poof";
public static final String VALID_IDENTITY = "BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"; public static final String VALID_IDENTITY = "BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo";
@ -76,8 +84,11 @@ public class AuthHelper {
when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6); when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6);
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER); when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO); when(VALID_ACCOUNT_TWO.getNumber()).thenReturn(VALID_NUMBER_TWO);
when(VALID_ACCOUNT_TWO.getUuid()).thenReturn(VALID_UUID_TWO);
when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER); when(DISABLED_ACCOUNT.getNumber()).thenReturn(DISABLED_NUMBER);
when(DISABLED_ACCOUNT.getUuid()).thenReturn(DISABLED_UUID);
when(VALID_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE)); when(VALID_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT_TWO.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO)); when(VALID_ACCOUNT_TWO.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO));
@ -91,9 +102,21 @@ public class AuthHelper {
when(DISABLED_ACCOUNT.isEnabled()).thenReturn(false); when(DISABLED_ACCOUNT.isEnabled()).thenReturn(false);
when(VALID_ACCOUNT.getIdentityKey()).thenReturn(VALID_IDENTITY); when(VALID_ACCOUNT.getIdentityKey()).thenReturn(VALID_IDENTITY);
when(ACCOUNTS_MANAGER.get(VALID_NUMBER)).thenReturn(Optional.of(VALID_ACCOUNT)); when(ACCOUNTS_MANAGER.get(VALID_NUMBER)).thenReturn(Optional.of(VALID_ACCOUNT));
when(ACCOUNTS_MANAGER.get(VALID_UUID)).thenReturn(Optional.of(VALID_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(VALID_NUMBER)))).thenReturn(Optional.of(VALID_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(VALID_UUID)))).thenReturn(Optional.of(VALID_ACCOUNT));
when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(VALID_UUID_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(DISABLED_UUID)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter (); AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter(); AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();

View File

@ -105,10 +105,13 @@ public class WebSocketConnectionTest {
public void testOpen() throws Exception { public void testOpen() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class); MessagesManager storedMessages = mock(MessagesManager.class);
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{ List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage(1L, false, "sender1", 1111, false, "first")); add(createMessage(1L, false, "sender1", senderOneUuid, 1111, false, "first"));
add(createMessage(2L, false, "sender1", 2222, false, "second")); add(createMessage(2L, false, "sender1", senderOneUuid, 2222, false, "second"));
add(createMessage(3L, false, "sender2", 3333, false, "third")); add(createMessage(3L, false, "sender2", senderTwoUuid, 3333, false, "third"));
}}; }};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
@ -121,7 +124,7 @@ public class WebSocketConnectionTest {
final Device sender1device = mock(Device.class); final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<Device>() {{ Set<Device> sender1devices = new HashSet<>() {{
add(sender1device); add(sender1device);
}}; }};
@ -275,6 +278,7 @@ public class WebSocketConnectionTest {
final Envelope firstMessage = Envelope.newBuilder() final Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes())) .setLegacyMessage(ByteString.copyFrom("first".getBytes()))
.setSource("sender1") .setSource("sender1")
.setSourceUuid(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis()) .setTimestamp(System.currentTimeMillis())
.setSourceDevice(1) .setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT) .setType(Envelope.Type.CIPHERTEXT)
@ -283,6 +287,7 @@ public class WebSocketConnectionTest {
final Envelope secondMessage = Envelope.newBuilder() final Envelope secondMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("second".getBytes())) .setLegacyMessage(ByteString.copyFrom("second".getBytes()))
.setSource("sender2") .setSource("sender2")
.setSourceUuid(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis()) .setTimestamp(System.currentTimeMillis())
.setSourceDevice(2) .setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT) .setType(Envelope.Type.CIPHERTEXT)
@ -290,11 +295,11 @@ public class WebSocketConnectionTest {
List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{ List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1, true, UUID.randomUUID(), firstMessage.getType().getNumber(), firstMessage.getRelay(), add(new OutgoingMessageEntity(1, true, UUID.randomUUID(), firstMessage.getType().getNumber(), firstMessage.getRelay(),
firstMessage.getTimestamp(), firstMessage.getSource(), firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()),
firstMessage.getSourceDevice(), firstMessage.getLegacyMessage().toByteArray(), firstMessage.getSourceDevice(), firstMessage.getLegacyMessage().toByteArray(),
firstMessage.getContent().toByteArray(), 0)); firstMessage.getContent().toByteArray(), 0));
add(new OutgoingMessageEntity(2, false, UUID.randomUUID(), secondMessage.getType().getNumber(), secondMessage.getRelay(), add(new OutgoingMessageEntity(2, false, UUID.randomUUID(), secondMessage.getType().getNumber(), secondMessage.getRelay(),
secondMessage.getTimestamp(), secondMessage.getSource(), secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()),
secondMessage.getSourceDevice(), secondMessage.getLegacyMessage().toByteArray(), secondMessage.getSourceDevice(), secondMessage.getLegacyMessage().toByteArray(),
secondMessage.getContent().toByteArray(), 0)); secondMessage.getContent().toByteArray(), 0));
}}; }};
@ -359,9 +364,9 @@ public class WebSocketConnectionTest {
} }
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, long timestamp, boolean receipt, String content) { private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, 1, content.getBytes(), null, 0); null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);
} }
} }