Fix for PubSub channel.

1) Create channels based on numbers rather than DB row ids.

2) Ensure that stored messages are cleared at reregistration
   time.
This commit is contained in:
Moxie Marlinspike 2014-07-26 20:41:25 -07:00
parent 4eb88a3e02
commit c9a1386a55
12 changed files with 77 additions and 91 deletions

View File

@ -172,7 +172,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
deviceAuthenticator,
Device.class, "WhisperServer"));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, storedMessages));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
environment.jersey().register(new DirectoryController(rateLimiters, directory));
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));

View File

@ -34,8 +34,10 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
@ -65,16 +67,19 @@ public class AccountController {
private final AccountsManager accounts;
private final RateLimiters rateLimiters;
private final SmsSender smsSender;
private final StoredMessages storedMessages;
public AccountController(PendingAccountsManager pendingAccounts,
AccountsManager accounts,
RateLimiters rateLimiters,
SmsSender smsSenderFactory)
SmsSender smsSenderFactory,
StoredMessages storedMessages)
{
this.pendingAccounts = pendingAccounts;
this.accounts = accounts;
this.rateLimiters = rateLimiters;
this.smsSender = smsSenderFactory;
this.storedMessages = storedMessages;
}
@Timed
@ -153,7 +158,7 @@ public class AccountController {
account.addDevice(device);
accounts.create(account);
storedMessages.clear(new WebsocketAddress(number, Device.MASTER_ID));
pendingAccounts.remove(number);
logger.debug("Stored device...");

View File

@ -94,7 +94,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
this.account = account.get();
this.device = account.get().getAuthenticatedDevice().get();
this.address = new WebsocketAddress(this.account.getId(), this.device.getId());
this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId());
this.session = session;
this.session.setIdleTimeout(10 * 60 * 1000);
@ -148,7 +148,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
pushSender.sendMessage(account, device, remainingMessage);
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("onWebSocketClose", e);
storedMessages.insert(account.getId(), device.getId(), remainingMessage);
storedMessages.insert(address, remainingMessage);
}
}
}
@ -208,7 +208,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
}
private void handleQueryDatabase() {
List<PendingMessage> messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
List<PendingMessage> messages = storedMessages.getMessagesForDevice(address);
for (PendingMessage message : messages) {
handleDeliverOutgoingMessage(message);

View File

@ -103,16 +103,16 @@ public class APNSender implements Managed {
throws TransientPushFailureException
{
try {
String serializedPendingMessage = mapper.writeValueAsString(message);
String serializedPendingMessage = mapper.writeValueAsString(message);
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()),
new PubSubMessage(PubSubMessage.TYPE_DELIVER,
serializedPendingMessage)))
if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER,
serializedPendingMessage)))
{
websocketMeter.mark();
} else {
memcacheSet(registrationId, account.getNumber());
storedMessages.insert(account.getId(), device.getId(), message);
storedMessages.insert(websocketAddress, message);
if (!message.isReceipt()) {
sendPush(registrationId, serializedPendingMessage);

View File

@ -57,14 +57,14 @@ public class WebsocketSender {
public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
try {
String serialized = mapper.writeValueAsString(pendingMessage);
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
if (pubSubManager.publish(address, pubSubMessage)) {
onlineMeter.mark();
} else {
offlineMeter.mark();
storedMessages.insert(account.getId(), device.getId(), pendingMessage);
storedMessages.insert(address, pendingMessage);
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
}
} catch (JsonProcessingException e) {

View File

@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
@ -30,9 +29,6 @@ public class Account {
public static final int MEMCACHE_VERION = 5;
@JsonIgnore
private long id;
@JsonProperty
private String number;
@ -57,14 +53,6 @@ public class Account {
this.devices = devices;
}
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice;
}

View File

@ -92,7 +92,7 @@ public abstract class Accounts {
{
try {
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setId(resultSet.getLong(ID));
// account.setId(resultSet.getLong(ID));
return account;
} catch (IOException e) {

View File

@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
public class PubSubManager {
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<String, PubSubListener> listeners = new HashMap<>();
private final JedisPool jedisPool;
private boolean subscribed = false;
@ -33,25 +35,29 @@ public class PubSubManager {
}
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
listeners.put(address, listener);
baseListener.subscribe(address.toString());
listeners.put(address.serialize(), listener);
baseListener.subscribe(address.serialize());
}
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
if (listeners.get(address) == listener) {
listeners.remove(address);
baseListener.unsubscribe(address.toString());
if (listeners.get(address.serialize()) == listener) {
listeners.remove(address.serialize());
baseListener.unsubscribe(address.serialize());
}
}
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize(), message);
}
private synchronized boolean publish(String channel, PubSubMessage message) {
try {
String serialized = mapper.writeValueAsString(message);
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
return jedis.publish(address.toString(), serialized) != 0;
return jedis.publish(channel, serialized) != 0;
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
@ -79,7 +85,7 @@ public class PubSubManager {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******");
} finally {
if (jedis != null)
@ -95,7 +101,7 @@ public class PubSubManager {
for (;;) {
try {
Thread.sleep(20000);
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
} catch (InterruptedException e) {
throw new AssertionError(e);
}
@ -109,18 +115,15 @@ public class PubSubManager {
@Override
public void onMessage(String channel, String message) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
PubSubListener listener;
synchronized (PubSubManager.this) {
listener = listeners.get(address);
listener = listeners.get(channel);
}
if (listener != null) {
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Address", e);
} catch (IOException e) {
logger.warn("IOE", e);
}
@ -133,17 +136,11 @@ public class PubSubManager {
@Override
public void onSubscribe(String channel, int count) {
try {
WebsocketAddress address = new WebsocketAddress(channel);
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
if (KEEPALIVE_CHANNEL.equals(channel)) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
} catch (InvalidWebsocketAddressException e) {
logger.warn("Weird address", e);
}
}

View File

@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.LinkedList;
@ -53,18 +54,30 @@ public class StoredMessages {
this.jedisPool = jedisPool;
}
public void insert(long accountId, long deviceId, PendingMessage message) {
public void clear(WebsocketAddress address) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.del(getKey(address));
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
public void insert(WebsocketAddress address, PendingMessage message) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
String serializedMessage = mapper.writeValueAsString(message);
long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage);
long queueSize = jedis.lpush(getKey(address), serializedMessage);
queueSizeHistogram.update(queueSize);
if (queueSize > 1000) {
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
jedis.ltrim(getKey(address), 0, 999);
}
} catch (JsonProcessingException e) {
@ -75,7 +88,7 @@ public class StoredMessages {
}
}
public List<PendingMessage> getMessagesForDevice(long accountId, long deviceId) {
public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
List<PendingMessage> messages = new LinkedList<>();
Jedis jedis = null;
@ -83,7 +96,7 @@ public class StoredMessages {
jedis = jedisPool.getResource();
String message;
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) {
while ((message = jedis.rpop(getKey(address))) != null) {
try {
messages.add(mapper.readValue(message, PendingMessage.class));
} catch (IOException e) {
@ -98,8 +111,8 @@ public class StoredMessages {
}
}
private String getKey(long accountId, long deviceId) {
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId;
private String getKey(WebsocketAddress address) {
return QUEUE_PREFIX + ":" + address.serialize();
}
}

View File

@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket;
public class WebsocketAddress {
private final long accountId;
private final long deviceId;
private final String number;
private final long deviceId;
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
try {
String[] parts = serialized.split(":");
if (parts == null || parts.length != 2) {
throw new InvalidWebsocketAddressException(serialized);
}
this.accountId = Long.parseLong(parts[0]);
this.deviceId = Long.parseLong(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
}
public WebsocketAddress(long accountId, long deviceId) {
this.accountId = accountId;
public WebsocketAddress(String number, long deviceId) {
this.number = number;
this.deviceId = deviceId;
}
public long getAccountId() {
return accountId;
}
public long getDeviceId() {
return deviceId;
public String serialize() {
return number + ":" + deviceId;
}
public String toString() {
return accountId + ":" + deviceId;
return serialize();
}
@Override
@ -45,13 +26,13 @@ public class WebsocketAddress {
WebsocketAddress that = (WebsocketAddress)other;
return
this.accountId == that.accountId &&
this.number.equals(that.number) &&
this.deviceId == that.deviceId;
}
@Override
public int hashCode() {
return (int)accountId ^ (int)deviceId;
return number.hashCode() ^ (int)deviceId;
}
}

View File

@ -13,6 +13,7 @@ import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.core.MediaType;
@ -31,6 +32,7 @@ public class AccountControllerTest {
private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class );
private SmsSender smsSender = mock(SmsSender.class );
private StoredMessages storedMessages = mock(StoredMessages.class );
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
@ -38,7 +40,8 @@ public class AccountControllerTest {
.addResource(new AccountController(pendingAccountsManager,
accountsManager,
rateLimiters,
smsSender))
smsSender,
storedMessages))
.build();

View File

@ -94,7 +94,6 @@ public class WebsocketControllerTest {
}};
when(device.getId()).thenReturn(2L);
when(account.getId()).thenReturn(31337L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(session.getRemote()).thenReturn(remote);
@ -120,7 +119,7 @@ public class WebsocketControllerTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
when(storedMessages.getMessagesForDevice(account.getId(), device.getId()))
when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId())))
.thenReturn(outgoingMessages);
WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager);
@ -128,7 +127,7 @@ public class WebsocketControllerTest {
controller.onWebSocketConnect(session);
verify(pubSubManager).subscribe(eq(new WebsocketAddress(31337L, 2L)), eq((controller)));
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller)));
verify(remote, times(3)).sendStringByFuture(anyString());
controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));