Fix for PubSub channel.

1) Create channels based on numbers rather than DB row ids.

2) Ensure that stored messages are cleared at reregistration
   time.
This commit is contained in:
Moxie Marlinspike 2014-07-26 20:41:25 -07:00
parent 4eb88a3e02
commit c9a1386a55
12 changed files with 77 additions and 91 deletions

View File

@ -172,7 +172,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
deviceAuthenticator, deviceAuthenticator,
Device.class, "WhisperServer")); Device.class, "WhisperServer"));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender)); environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, storedMessages));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters)); environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
environment.jersey().register(new DirectoryController(rateLimiters, directory)); environment.jersey().register(new DirectoryController(rateLimiters, directory));
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1)); environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));

View File

@ -34,8 +34,10 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -65,16 +67,19 @@ public class AccountController {
private final AccountsManager accounts; private final AccountsManager accounts;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final SmsSender smsSender; private final SmsSender smsSender;
private final StoredMessages storedMessages;
public AccountController(PendingAccountsManager pendingAccounts, public AccountController(PendingAccountsManager pendingAccounts,
AccountsManager accounts, AccountsManager accounts,
RateLimiters rateLimiters, RateLimiters rateLimiters,
SmsSender smsSenderFactory) SmsSender smsSenderFactory,
StoredMessages storedMessages)
{ {
this.pendingAccounts = pendingAccounts; this.pendingAccounts = pendingAccounts;
this.accounts = accounts; this.accounts = accounts;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.smsSender = smsSenderFactory; this.smsSender = smsSenderFactory;
this.storedMessages = storedMessages;
} }
@Timed @Timed
@ -153,7 +158,7 @@ public class AccountController {
account.addDevice(device); account.addDevice(device);
accounts.create(account); accounts.create(account);
storedMessages.clear(new WebsocketAddress(number, Device.MASTER_ID));
pendingAccounts.remove(number); pendingAccounts.remove(number);
logger.debug("Stored device..."); logger.debug("Stored device...");

View File

@ -94,7 +94,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
this.account = account.get(); this.account = account.get();
this.device = account.get().getAuthenticatedDevice().get(); this.device = account.get().getAuthenticatedDevice().get();
this.address = new WebsocketAddress(this.account.getId(), this.device.getId()); this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId());
this.session = session; this.session = session;
this.session.setIdleTimeout(10 * 60 * 1000); this.session.setIdleTimeout(10 * 60 * 1000);
@ -148,7 +148,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
pushSender.sendMessage(account, device, remainingMessage); pushSender.sendMessage(account, device, remainingMessage);
} catch (NotPushRegisteredException | TransientPushFailureException e) { } catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("onWebSocketClose", e); logger.warn("onWebSocketClose", e);
storedMessages.insert(account.getId(), device.getId(), remainingMessage); storedMessages.insert(address, remainingMessage);
} }
} }
} }
@ -208,7 +208,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
} }
private void handleQueryDatabase() { private void handleQueryDatabase() {
List<PendingMessage> messages = storedMessages.getMessagesForDevice(account.getId(), device.getId()); List<PendingMessage> messages = storedMessages.getMessagesForDevice(address);
for (PendingMessage message : messages) { for (PendingMessage message : messages) {
handleDeliverOutgoingMessage(message); handleDeliverOutgoingMessage(message);

View File

@ -104,15 +104,15 @@ public class APNSender implements Managed {
{ {
try { try {
String serializedPendingMessage = mapper.writeValueAsString(message); String serializedPendingMessage = mapper.writeValueAsString(message);
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()), if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER,
new PubSubMessage(PubSubMessage.TYPE_DELIVER,
serializedPendingMessage))) serializedPendingMessage)))
{ {
websocketMeter.mark(); websocketMeter.mark();
} else { } else {
memcacheSet(registrationId, account.getNumber()); memcacheSet(registrationId, account.getNumber());
storedMessages.insert(account.getId(), device.getId(), message); storedMessages.insert(websocketAddress, message);
if (!message.isReceipt()) { if (!message.isReceipt()) {
sendPush(registrationId, serializedPendingMessage); sendPush(registrationId, serializedPendingMessage);

View File

@ -57,14 +57,14 @@ public class WebsocketSender {
public void sendMessage(Account account, Device device, PendingMessage pendingMessage) { public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
try { try {
String serialized = mapper.writeValueAsString(pendingMessage); String serialized = mapper.writeValueAsString(pendingMessage);
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId()); WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized); PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
if (pubSubManager.publish(address, pubSubMessage)) { if (pubSubManager.publish(address, pubSubMessage)) {
onlineMeter.mark(); onlineMeter.mark();
} else { } else {
offlineMeter.mark(); offlineMeter.mark();
storedMessages.insert(account.getId(), device.getId(), pendingMessage); storedMessages.insert(address, pendingMessage);
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
} }
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {

View File

@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import java.io.Serializable;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -30,9 +29,6 @@ public class Account {
public static final int MEMCACHE_VERION = 5; public static final int MEMCACHE_VERION = 5;
@JsonIgnore
private long id;
@JsonProperty @JsonProperty
private String number; private String number;
@ -57,14 +53,6 @@ public class Account {
this.devices = devices; this.devices = devices;
} }
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public Optional<Device> getAuthenticatedDevice() { public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice; return authenticatedDevice;
} }

View File

@ -92,7 +92,7 @@ public abstract class Accounts {
{ {
try { try {
Account account = mapper.readValue(resultSet.getString(DATA), Account.class); Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setId(resultSet.getLong(ID)); // account.setId(resultSet.getLong(ID));
return account; return account;
} catch (IOException e) { } catch (IOException e) {

View File

@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
public class PubSubManager { public class PubSubManager {
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = SystemMapper.getMapper(); private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener(); private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>(); private final Map<String, PubSubListener> listeners = new HashMap<>();
private final JedisPool jedisPool; private final JedisPool jedisPool;
private boolean subscribed = false; private boolean subscribed = false;
@ -33,25 +35,29 @@ public class PubSubManager {
} }
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
listeners.put(address, listener); listeners.put(address.serialize(), listener);
baseListener.subscribe(address.toString()); baseListener.subscribe(address.serialize());
} }
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
if (listeners.get(address) == listener) { if (listeners.get(address.serialize()) == listener) {
listeners.remove(address); listeners.remove(address.serialize());
baseListener.unsubscribe(address.toString()); baseListener.unsubscribe(address.serialize());
} }
} }
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize(), message);
}
private synchronized boolean publish(String channel, PubSubMessage message) {
try { try {
String serialized = mapper.writeValueAsString(message); String serialized = mapper.writeValueAsString(message);
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
return jedis.publish(address.toString(), serialized) != 0; return jedis.publish(channel, serialized) != 0;
} finally { } finally {
if (jedis != null) if (jedis != null)
jedisPool.returnResource(jedis); jedisPool.returnResource(jedis);
@ -79,7 +85,7 @@ public class PubSubManager {
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString()); jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******"); logger.warn("**** Unsubscribed from holding channel!!! ******");
} finally { } finally {
if (jedis != null) if (jedis != null)
@ -95,7 +101,7 @@ public class PubSubManager {
for (;;) { for (;;) {
try { try {
Thread.sleep(20000); Thread.sleep(20000);
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo")); publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@ -109,18 +115,15 @@ public class PubSubManager {
@Override @Override
public void onMessage(String channel, String message) { public void onMessage(String channel, String message) {
try { try {
WebsocketAddress address = new WebsocketAddress(channel);
PubSubListener listener; PubSubListener listener;
synchronized (PubSubManager.this) { synchronized (PubSubManager.this) {
listener = listeners.get(address); listener = listeners.get(channel);
} }
if (listener != null) { if (listener != null) {
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class)); listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
} }
} catch (InvalidWebsocketAddressException e) {
logger.warn("Address", e);
} catch (IOException e) { } catch (IOException e) {
logger.warn("IOE", e); logger.warn("IOE", e);
} }
@ -133,18 +136,12 @@ public class PubSubManager {
@Override @Override
public void onSubscribe(String channel, int count) { public void onSubscribe(String channel, int count) {
try { if (KEEPALIVE_CHANNEL.equals(channel)) {
WebsocketAddress address = new WebsocketAddress(channel);
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
synchronized (PubSubManager.this) { synchronized (PubSubManager.this) {
subscribed = true; subscribed = true;
PubSubManager.this.notifyAll(); PubSubManager.this.notifyAll();
} }
} }
} catch (InvalidWebsocketAddressException e) {
logger.warn("Weird address", e);
}
} }
@Override @Override

View File

@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException; import java.io.IOException;
import java.util.LinkedList; import java.util.LinkedList;
@ -53,18 +54,30 @@ public class StoredMessages {
this.jedisPool = jedisPool; this.jedisPool = jedisPool;
} }
public void insert(long accountId, long deviceId, PendingMessage message) { public void clear(WebsocketAddress address) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
jedis.del(getKey(address));
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
public void insert(WebsocketAddress address, PendingMessage message) {
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
String serializedMessage = mapper.writeValueAsString(message); String serializedMessage = mapper.writeValueAsString(message);
long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage); long queueSize = jedis.lpush(getKey(address), serializedMessage);
queueSizeHistogram.update(queueSize); queueSizeHistogram.update(queueSize);
if (queueSize > 1000) { if (queueSize > 1000) {
jedis.ltrim(getKey(accountId, deviceId), 0, 999); jedis.ltrim(getKey(address), 0, 999);
} }
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
@ -75,7 +88,7 @@ public class StoredMessages {
} }
} }
public List<PendingMessage> getMessagesForDevice(long accountId, long deviceId) { public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
List<PendingMessage> messages = new LinkedList<>(); List<PendingMessage> messages = new LinkedList<>();
Jedis jedis = null; Jedis jedis = null;
@ -83,7 +96,7 @@ public class StoredMessages {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
String message; String message;
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) { while ((message = jedis.rpop(getKey(address))) != null) {
try { try {
messages.add(mapper.readValue(message, PendingMessage.class)); messages.add(mapper.readValue(message, PendingMessage.class));
} catch (IOException e) { } catch (IOException e) {
@ -98,8 +111,8 @@ public class StoredMessages {
} }
} }
private String getKey(long accountId, long deviceId) { private String getKey(WebsocketAddress address) {
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId; return QUEUE_PREFIX + ":" + address.serialize();
} }
} }

View File

@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket;
public class WebsocketAddress { public class WebsocketAddress {
private final long accountId; private final String number;
private final long deviceId; private final long deviceId;
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException { public WebsocketAddress(String number, long deviceId) {
try { this.number = number;
String[] parts = serialized.split(":");
if (parts == null || parts.length != 2) {
throw new InvalidWebsocketAddressException(serialized);
}
this.accountId = Long.parseLong(parts[0]);
this.deviceId = Long.parseLong(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
}
public WebsocketAddress(long accountId, long deviceId) {
this.accountId = accountId;
this.deviceId = deviceId; this.deviceId = deviceId;
} }
public long getAccountId() { public String serialize() {
return accountId; return number + ":" + deviceId;
}
public long getDeviceId() {
return deviceId;
} }
public String toString() { public String toString() {
return accountId + ":" + deviceId; return serialize();
} }
@Override @Override
@ -45,13 +26,13 @@ public class WebsocketAddress {
WebsocketAddress that = (WebsocketAddress)other; WebsocketAddress that = (WebsocketAddress)other;
return return
this.accountId == that.accountId && this.number.equals(that.number) &&
this.deviceId == that.deviceId; this.deviceId == that.deviceId;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return (int)accountId ^ (int)deviceId; return number.hashCode() ^ (int)deviceId;
} }
} }

View File

@ -13,6 +13,7 @@ import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
@ -31,6 +32,7 @@ public class AccountControllerTest {
private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class ); private RateLimiter rateLimiter = mock(RateLimiter.class );
private SmsSender smsSender = mock(SmsSender.class ); private SmsSender smsSender = mock(SmsSender.class );
private StoredMessages storedMessages = mock(StoredMessages.class );
@Rule @Rule
public final ResourceTestRule resources = ResourceTestRule.builder() public final ResourceTestRule resources = ResourceTestRule.builder()
@ -38,7 +40,8 @@ public class AccountControllerTest {
.addResource(new AccountController(pendingAccountsManager, .addResource(new AccountController(pendingAccountsManager,
accountsManager, accountsManager,
rateLimiters, rateLimiters,
smsSender)) smsSender,
storedMessages))
.build(); .build();

View File

@ -94,7 +94,6 @@ public class WebsocketControllerTest {
}}; }};
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getId()).thenReturn(31337L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(session.getRemote()).thenReturn(remote); when(session.getRemote()).thenReturn(remote);
@ -120,7 +119,7 @@ public class WebsocketControllerTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account)); .thenReturn(Optional.of(account));
when(storedMessages.getMessagesForDevice(account.getId(), device.getId())) when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId())))
.thenReturn(outgoingMessages); .thenReturn(outgoingMessages);
WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager); WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager);
@ -128,7 +127,7 @@ public class WebsocketControllerTest {
controller.onWebSocketConnect(session); controller.onWebSocketConnect(session);
verify(pubSubManager).subscribe(eq(new WebsocketAddress(31337L, 2L)), eq((controller))); verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller)));
verify(remote, times(3)).sendStringByFuture(anyString()); verify(remote, times(3)).sendStringByFuture(anyString());
controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1))); controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));