Fix for PubSub channel.
1) Create channels based on numbers rather than DB row ids. 2) Ensure that stored messages are cleared at reregistration time.
This commit is contained in:
parent
4eb88a3e02
commit
c9a1386a55
|
@ -172,7 +172,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
deviceAuthenticator,
|
||||
Device.class, "WhisperServer"));
|
||||
|
||||
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
|
||||
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, storedMessages));
|
||||
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
|
||||
environment.jersey().register(new DirectoryController(rateLimiters, directory));
|
||||
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
|
||||
|
|
|
@ -34,8 +34,10 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
|||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.StoredMessages;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import org.whispersystems.textsecuregcm.util.VerificationCode;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import javax.validation.Valid;
|
||||
import javax.ws.rs.Consumes;
|
||||
|
@ -65,16 +67,19 @@ public class AccountController {
|
|||
private final AccountsManager accounts;
|
||||
private final RateLimiters rateLimiters;
|
||||
private final SmsSender smsSender;
|
||||
private final StoredMessages storedMessages;
|
||||
|
||||
public AccountController(PendingAccountsManager pendingAccounts,
|
||||
AccountsManager accounts,
|
||||
RateLimiters rateLimiters,
|
||||
SmsSender smsSenderFactory)
|
||||
SmsSender smsSenderFactory,
|
||||
StoredMessages storedMessages)
|
||||
{
|
||||
this.pendingAccounts = pendingAccounts;
|
||||
this.accounts = accounts;
|
||||
this.rateLimiters = rateLimiters;
|
||||
this.smsSender = smsSenderFactory;
|
||||
this.storedMessages = storedMessages;
|
||||
}
|
||||
|
||||
@Timed
|
||||
|
@ -153,7 +158,7 @@ public class AccountController {
|
|||
account.addDevice(device);
|
||||
|
||||
accounts.create(account);
|
||||
|
||||
storedMessages.clear(new WebsocketAddress(number, Device.MASTER_ID));
|
||||
pendingAccounts.remove(number);
|
||||
|
||||
logger.debug("Stored device...");
|
||||
|
|
|
@ -94,7 +94,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
|||
|
||||
this.account = account.get();
|
||||
this.device = account.get().getAuthenticatedDevice().get();
|
||||
this.address = new WebsocketAddress(this.account.getId(), this.device.getId());
|
||||
this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId());
|
||||
this.session = session;
|
||||
|
||||
this.session.setIdleTimeout(10 * 60 * 1000);
|
||||
|
@ -148,7 +148,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
|||
pushSender.sendMessage(account, device, remainingMessage);
|
||||
} catch (NotPushRegisteredException | TransientPushFailureException e) {
|
||||
logger.warn("onWebSocketClose", e);
|
||||
storedMessages.insert(account.getId(), device.getId(), remainingMessage);
|
||||
storedMessages.insert(address, remainingMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
|
|||
}
|
||||
|
||||
private void handleQueryDatabase() {
|
||||
List<PendingMessage> messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
|
||||
List<PendingMessage> messages = storedMessages.getMessagesForDevice(address);
|
||||
|
||||
for (PendingMessage message : messages) {
|
||||
handleDeliverOutgoingMessage(message);
|
||||
|
|
|
@ -103,16 +103,16 @@ public class APNSender implements Managed {
|
|||
throws TransientPushFailureException
|
||||
{
|
||||
try {
|
||||
String serializedPendingMessage = mapper.writeValueAsString(message);
|
||||
String serializedPendingMessage = mapper.writeValueAsString(message);
|
||||
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
|
||||
|
||||
if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()),
|
||||
new PubSubMessage(PubSubMessage.TYPE_DELIVER,
|
||||
serializedPendingMessage)))
|
||||
if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER,
|
||||
serializedPendingMessage)))
|
||||
{
|
||||
websocketMeter.mark();
|
||||
} else {
|
||||
memcacheSet(registrationId, account.getNumber());
|
||||
storedMessages.insert(account.getId(), device.getId(), message);
|
||||
storedMessages.insert(websocketAddress, message);
|
||||
|
||||
if (!message.isReceipt()) {
|
||||
sendPush(registrationId, serializedPendingMessage);
|
||||
|
|
|
@ -57,14 +57,14 @@ public class WebsocketSender {
|
|||
public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
|
||||
try {
|
||||
String serialized = mapper.writeValueAsString(pendingMessage);
|
||||
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
|
||||
WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
|
||||
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
|
||||
|
||||
if (pubSubManager.publish(address, pubSubMessage)) {
|
||||
onlineMeter.mark();
|
||||
} else {
|
||||
offlineMeter.mark();
|
||||
storedMessages.insert(account.getId(), device.getId(), pendingMessage);
|
||||
storedMessages.insert(address, pendingMessage);
|
||||
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
|
|
|
@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
|||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Optional;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -30,9 +29,6 @@ public class Account {
|
|||
|
||||
public static final int MEMCACHE_VERION = 5;
|
||||
|
||||
@JsonIgnore
|
||||
private long id;
|
||||
|
||||
@JsonProperty
|
||||
private String number;
|
||||
|
||||
|
@ -57,14 +53,6 @@ public class Account {
|
|||
this.devices = devices;
|
||||
}
|
||||
|
||||
public long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public Optional<Device> getAuthenticatedDevice() {
|
||||
return authenticatedDevice;
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ public abstract class Accounts {
|
|||
{
|
||||
try {
|
||||
Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
|
||||
account.setId(resultSet.getLong(ID));
|
||||
// account.setId(resultSet.getLong(ID));
|
||||
|
||||
return account;
|
||||
} catch (IOException e) {
|
||||
|
|
|
@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
|
|||
|
||||
public class PubSubManager {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
private final SubscriptionListener baseListener = new SubscriptionListener();
|
||||
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>();
|
||||
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
private final SubscriptionListener baseListener = new SubscriptionListener();
|
||||
private final Map<String, PubSubListener> listeners = new HashMap<>();
|
||||
|
||||
private final JedisPool jedisPool;
|
||||
private boolean subscribed = false;
|
||||
|
@ -33,25 +35,29 @@ public class PubSubManager {
|
|||
}
|
||||
|
||||
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
listeners.put(address, listener);
|
||||
baseListener.subscribe(address.toString());
|
||||
listeners.put(address.serialize(), listener);
|
||||
baseListener.subscribe(address.serialize());
|
||||
}
|
||||
|
||||
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
|
||||
if (listeners.get(address) == listener) {
|
||||
listeners.remove(address);
|
||||
baseListener.unsubscribe(address.toString());
|
||||
if (listeners.get(address.serialize()) == listener) {
|
||||
listeners.remove(address.serialize());
|
||||
baseListener.unsubscribe(address.serialize());
|
||||
}
|
||||
}
|
||||
|
||||
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
|
||||
return publish(address.serialize(), message);
|
||||
}
|
||||
|
||||
private synchronized boolean publish(String channel, PubSubMessage message) {
|
||||
try {
|
||||
String serialized = mapper.writeValueAsString(message);
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
return jedis.publish(address.toString(), serialized) != 0;
|
||||
return jedis.publish(channel, serialized) != 0;
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
|
@ -79,7 +85,7 @@ public class PubSubManager {
|
|||
Jedis jedis = null;
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
|
||||
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
|
||||
logger.warn("**** Unsubscribed from holding channel!!! ******");
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
|
@ -95,7 +101,7 @@ public class PubSubManager {
|
|||
for (;;) {
|
||||
try {
|
||||
Thread.sleep(20000);
|
||||
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
|
||||
publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
|
||||
} catch (InterruptedException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
|
@ -109,18 +115,15 @@ public class PubSubManager {
|
|||
@Override
|
||||
public void onMessage(String channel, String message) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
PubSubListener listener;
|
||||
|
||||
synchronized (PubSubManager.this) {
|
||||
listener = listeners.get(address);
|
||||
listener = listeners.get(channel);
|
||||
}
|
||||
|
||||
if (listener != null) {
|
||||
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Address", e);
|
||||
} catch (IOException e) {
|
||||
logger.warn("IOE", e);
|
||||
}
|
||||
|
@ -133,17 +136,11 @@ public class PubSubManager {
|
|||
|
||||
@Override
|
||||
public void onSubscribe(String channel, int count) {
|
||||
try {
|
||||
WebsocketAddress address = new WebsocketAddress(channel);
|
||||
|
||||
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
|
||||
synchronized (PubSubManager.this) {
|
||||
subscribed = true;
|
||||
PubSubManager.this.notifyAll();
|
||||
}
|
||||
if (KEEPALIVE_CHANNEL.equals(channel)) {
|
||||
synchronized (PubSubManager.this) {
|
||||
subscribed = true;
|
||||
PubSubManager.this.notifyAll();
|
||||
}
|
||||
} catch (InvalidWebsocketAddressException e) {
|
||||
logger.warn("Weird address", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
|
|||
import org.whispersystems.textsecuregcm.entities.PendingMessage;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedList;
|
||||
|
@ -53,18 +54,30 @@ public class StoredMessages {
|
|||
this.jedisPool = jedisPool;
|
||||
}
|
||||
|
||||
public void insert(long accountId, long deviceId, PendingMessage message) {
|
||||
public void clear(WebsocketAddress address) {
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
jedis.del(getKey(address));
|
||||
} finally {
|
||||
if (jedis != null)
|
||||
jedisPool.returnResource(jedis);
|
||||
}
|
||||
}
|
||||
|
||||
public void insert(WebsocketAddress address, PendingMessage message) {
|
||||
Jedis jedis = null;
|
||||
|
||||
try {
|
||||
jedis = jedisPool.getResource();
|
||||
|
||||
String serializedMessage = mapper.writeValueAsString(message);
|
||||
long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage);
|
||||
long queueSize = jedis.lpush(getKey(address), serializedMessage);
|
||||
queueSizeHistogram.update(queueSize);
|
||||
|
||||
if (queueSize > 1000) {
|
||||
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
|
||||
jedis.ltrim(getKey(address), 0, 999);
|
||||
}
|
||||
|
||||
} catch (JsonProcessingException e) {
|
||||
|
@ -75,7 +88,7 @@ public class StoredMessages {
|
|||
}
|
||||
}
|
||||
|
||||
public List<PendingMessage> getMessagesForDevice(long accountId, long deviceId) {
|
||||
public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
|
||||
List<PendingMessage> messages = new LinkedList<>();
|
||||
Jedis jedis = null;
|
||||
|
||||
|
@ -83,7 +96,7 @@ public class StoredMessages {
|
|||
jedis = jedisPool.getResource();
|
||||
String message;
|
||||
|
||||
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) {
|
||||
while ((message = jedis.rpop(getKey(address))) != null) {
|
||||
try {
|
||||
messages.add(mapper.readValue(message, PendingMessage.class));
|
||||
} catch (IOException e) {
|
||||
|
@ -98,8 +111,8 @@ public class StoredMessages {
|
|||
}
|
||||
}
|
||||
|
||||
private String getKey(long accountId, long deviceId) {
|
||||
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId;
|
||||
private String getKey(WebsocketAddress address) {
|
||||
return QUEUE_PREFIX + ":" + address.serialize();
|
||||
}
|
||||
|
||||
}
|
|
@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket;
|
|||
|
||||
public class WebsocketAddress {
|
||||
|
||||
private final long accountId;
|
||||
private final long deviceId;
|
||||
private final String number;
|
||||
private final long deviceId;
|
||||
|
||||
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
|
||||
try {
|
||||
String[] parts = serialized.split(":");
|
||||
|
||||
if (parts == null || parts.length != 2) {
|
||||
throw new InvalidWebsocketAddressException(serialized);
|
||||
}
|
||||
|
||||
this.accountId = Long.parseLong(parts[0]);
|
||||
this.deviceId = Long.parseLong(parts[1]);
|
||||
} catch (NumberFormatException e) {
|
||||
throw new InvalidWebsocketAddressException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public WebsocketAddress(long accountId, long deviceId) {
|
||||
this.accountId = accountId;
|
||||
public WebsocketAddress(String number, long deviceId) {
|
||||
this.number = number;
|
||||
this.deviceId = deviceId;
|
||||
}
|
||||
|
||||
public long getAccountId() {
|
||||
return accountId;
|
||||
}
|
||||
|
||||
public long getDeviceId() {
|
||||
return deviceId;
|
||||
public String serialize() {
|
||||
return number + ":" + deviceId;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return accountId + ":" + deviceId;
|
||||
return serialize();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,13 +26,13 @@ public class WebsocketAddress {
|
|||
WebsocketAddress that = (WebsocketAddress)other;
|
||||
|
||||
return
|
||||
this.accountId == that.accountId &&
|
||||
this.number.equals(that.number) &&
|
||||
this.deviceId == that.deviceId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return (int)accountId ^ (int)deviceId;
|
||||
return number.hashCode() ^ (int)deviceId;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.whispersystems.textsecuregcm.sms.SmsSender;
|
|||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.StoredMessages;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
|
||||
import javax.ws.rs.core.MediaType;
|
||||
|
@ -31,6 +32,7 @@ public class AccountControllerTest {
|
|||
private RateLimiters rateLimiters = mock(RateLimiters.class );
|
||||
private RateLimiter rateLimiter = mock(RateLimiter.class );
|
||||
private SmsSender smsSender = mock(SmsSender.class );
|
||||
private StoredMessages storedMessages = mock(StoredMessages.class );
|
||||
|
||||
@Rule
|
||||
public final ResourceTestRule resources = ResourceTestRule.builder()
|
||||
|
@ -38,7 +40,8 @@ public class AccountControllerTest {
|
|||
.addResource(new AccountController(pendingAccountsManager,
|
||||
accountsManager,
|
||||
rateLimiters,
|
||||
smsSender))
|
||||
smsSender,
|
||||
storedMessages))
|
||||
.build();
|
||||
|
||||
|
||||
|
|
|
@ -94,7 +94,6 @@ public class WebsocketControllerTest {
|
|||
}};
|
||||
|
||||
when(device.getId()).thenReturn(2L);
|
||||
when(account.getId()).thenReturn(31337L);
|
||||
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
|
||||
when(account.getNumber()).thenReturn("+14152222222");
|
||||
when(session.getRemote()).thenReturn(remote);
|
||||
|
@ -120,7 +119,7 @@ public class WebsocketControllerTest {
|
|||
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
|
||||
.thenReturn(Optional.of(account));
|
||||
|
||||
when(storedMessages.getMessagesForDevice(account.getId(), device.getId()))
|
||||
when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId())))
|
||||
.thenReturn(outgoingMessages);
|
||||
|
||||
WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager);
|
||||
|
@ -128,7 +127,7 @@ public class WebsocketControllerTest {
|
|||
|
||||
controller.onWebSocketConnect(session);
|
||||
|
||||
verify(pubSubManager).subscribe(eq(new WebsocketAddress(31337L, 2L)), eq((controller)));
|
||||
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller)));
|
||||
verify(remote, times(3)).sendStringByFuture(anyString());
|
||||
|
||||
controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));
|
||||
|
|
Loading…
Reference in New Issue