diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index c745e81a5..968ab382c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -172,7 +172,7 @@ public class WhisperServerService extends Application messages = storedMessages.getMessagesForDevice(account.getId(), device.getId()); + List messages = storedMessages.getMessagesForDevice(address); for (PendingMessage message : messages) { handleDeliverOutgoingMessage(message); diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java index 734ffd36b..06272cce2 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java @@ -103,16 +103,16 @@ public class APNSender implements Managed { throws TransientPushFailureException { try { - String serializedPendingMessage = mapper.writeValueAsString(message); + String serializedPendingMessage = mapper.writeValueAsString(message); + WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId()); - if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()), - new PubSubMessage(PubSubMessage.TYPE_DELIVER, - serializedPendingMessage))) + if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER, + serializedPendingMessage))) { websocketMeter.mark(); } else { memcacheSet(registrationId, account.getNumber()); - storedMessages.insert(account.getId(), device.getId(), message); + storedMessages.insert(websocketAddress, message); if (!message.isReceipt()) { sendPush(registrationId, serializedPendingMessage); diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index 2bbc8d763..fe3cd00e4 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -57,14 +57,14 @@ public class WebsocketSender { public void sendMessage(Account account, Device device, PendingMessage pendingMessage) { try { String serialized = mapper.writeValueAsString(pendingMessage); - WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId()); + WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized); if (pubSubManager.publish(address, pubSubMessage)) { onlineMeter.mark(); } else { offlineMeter.mark(); - storedMessages.insert(account.getId(), device.getId(), pendingMessage); + storedMessages.insert(address, pendingMessage); pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); } } catch (JsonProcessingException e) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index b7ea1743f..63019bd00 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; -import java.io.Serializable; import java.util.LinkedList; import java.util.List; @@ -30,9 +29,6 @@ public class Account { public static final int MEMCACHE_VERION = 5; - @JsonIgnore - private long id; - @JsonProperty private String number; @@ -57,14 +53,6 @@ public class Account { this.devices = devices; } - public long getId() { - return id; - } - - public void setId(long id) { - this.id = id; - } - public Optional getAuthenticatedDevice() { return authenticatedDevice; } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 57cabcbb9..9eccb834d 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -92,7 +92,7 @@ public abstract class Accounts { { try { Account account = mapper.readValue(resultSet.getString(DATA), Account.class); - account.setId(resultSet.getLong(ID)); +// account.setId(resultSet.getLong(ID)); return account; } catch (IOException e) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java index 4ebb4e841..3b4acd0ef 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java @@ -18,10 +18,12 @@ import redis.clients.jedis.JedisPubSub; public class PubSubManager { - private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); - private final ObjectMapper mapper = SystemMapper.getMapper(); - private final SubscriptionListener baseListener = new SubscriptionListener(); - private final Map listeners = new HashMap<>(); + private static final String KEEPALIVE_CHANNEL = "KEEPALIVE"; + + private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); + private final ObjectMapper mapper = SystemMapper.getMapper(); + private final SubscriptionListener baseListener = new SubscriptionListener(); + private final Map listeners = new HashMap<>(); private final JedisPool jedisPool; private boolean subscribed = false; @@ -33,25 +35,29 @@ public class PubSubManager { } public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { - listeners.put(address, listener); - baseListener.subscribe(address.toString()); + listeners.put(address.serialize(), listener); + baseListener.subscribe(address.serialize()); } public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { - if (listeners.get(address) == listener) { - listeners.remove(address); - baseListener.unsubscribe(address.toString()); + if (listeners.get(address.serialize()) == listener) { + listeners.remove(address.serialize()); + baseListener.unsubscribe(address.serialize()); } } public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { + return publish(address.serialize(), message); + } + + private synchronized boolean publish(String channel, PubSubMessage message) { try { String serialized = mapper.writeValueAsString(message); Jedis jedis = null; try { jedis = jedisPool.getResource(); - return jedis.publish(address.toString(), serialized) != 0; + return jedis.publish(channel, serialized) != 0; } finally { if (jedis != null) jedisPool.returnResource(jedis); @@ -79,7 +85,7 @@ public class PubSubManager { Jedis jedis = null; try { jedis = jedisPool.getResource(); - jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString()); + jedis.subscribe(baseListener, KEEPALIVE_CHANNEL); logger.warn("**** Unsubscribed from holding channel!!! ******"); } finally { if (jedis != null) @@ -95,7 +101,7 @@ public class PubSubManager { for (;;) { try { Thread.sleep(20000); - publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo")); + publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo")); } catch (InterruptedException e) { throw new AssertionError(e); } @@ -109,18 +115,15 @@ public class PubSubManager { @Override public void onMessage(String channel, String message) { try { - WebsocketAddress address = new WebsocketAddress(channel); PubSubListener listener; synchronized (PubSubManager.this) { - listener = listeners.get(address); + listener = listeners.get(channel); } if (listener != null) { listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class)); } - } catch (InvalidWebsocketAddressException e) { - logger.warn("Address", e); } catch (IOException e) { logger.warn("IOE", e); } @@ -133,17 +136,11 @@ public class PubSubManager { @Override public void onSubscribe(String channel, int count) { - try { - WebsocketAddress address = new WebsocketAddress(channel); - - if (address.getAccountId() == 0 && address.getDeviceId() == 0) { - synchronized (PubSubManager.this) { - subscribed = true; - PubSubManager.this.notifyAll(); - } + if (KEEPALIVE_CHANNEL.equals(channel)) { + synchronized (PubSubManager.this) { + subscribed = true; + PubSubManager.this.notifyAll(); } - } catch (InvalidWebsocketAddressException e) { - logger.warn("Weird address", e); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java index b40b170a2..4ba277041 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java @@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.io.IOException; import java.util.LinkedList; @@ -53,18 +54,30 @@ public class StoredMessages { this.jedisPool = jedisPool; } - public void insert(long accountId, long deviceId, PendingMessage message) { + public void clear(WebsocketAddress address) { + Jedis jedis = null; + + try { + jedis = jedisPool.getResource(); + jedis.del(getKey(address)); + } finally { + if (jedis != null) + jedisPool.returnResource(jedis); + } + } + + public void insert(WebsocketAddress address, PendingMessage message) { Jedis jedis = null; try { jedis = jedisPool.getResource(); String serializedMessage = mapper.writeValueAsString(message); - long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage); + long queueSize = jedis.lpush(getKey(address), serializedMessage); queueSizeHistogram.update(queueSize); if (queueSize > 1000) { - jedis.ltrim(getKey(accountId, deviceId), 0, 999); + jedis.ltrim(getKey(address), 0, 999); } } catch (JsonProcessingException e) { @@ -75,7 +88,7 @@ public class StoredMessages { } } - public List getMessagesForDevice(long accountId, long deviceId) { + public List getMessagesForDevice(WebsocketAddress address) { List messages = new LinkedList<>(); Jedis jedis = null; @@ -83,7 +96,7 @@ public class StoredMessages { jedis = jedisPool.getResource(); String message; - while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) { + while ((message = jedis.rpop(getKey(address))) != null) { try { messages.add(mapper.readValue(message, PendingMessage.class)); } catch (IOException e) { @@ -98,8 +111,8 @@ public class StoredMessages { } } - private String getKey(long accountId, long deviceId) { - return QUEUE_PREFIX + ":" + accountId + ":" + deviceId; + private String getKey(WebsocketAddress address) { + return QUEUE_PREFIX + ":" + address.serialize(); } } \ No newline at end of file diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java index 72f72778c..fbd381c34 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java @@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket; public class WebsocketAddress { - private final long accountId; - private final long deviceId; + private final String number; + private final long deviceId; - public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException { - try { - String[] parts = serialized.split(":"); - - if (parts == null || parts.length != 2) { - throw new InvalidWebsocketAddressException(serialized); - } - - this.accountId = Long.parseLong(parts[0]); - this.deviceId = Long.parseLong(parts[1]); - } catch (NumberFormatException e) { - throw new InvalidWebsocketAddressException(e); - } - } - - public WebsocketAddress(long accountId, long deviceId) { - this.accountId = accountId; + public WebsocketAddress(String number, long deviceId) { + this.number = number; this.deviceId = deviceId; } - public long getAccountId() { - return accountId; - } - - public long getDeviceId() { - return deviceId; + public String serialize() { + return number + ":" + deviceId; } public String toString() { - return accountId + ":" + deviceId; + return serialize(); } @Override @@ -45,13 +26,13 @@ public class WebsocketAddress { WebsocketAddress that = (WebsocketAddress)other; return - this.accountId == that.accountId && + this.number.equals(that.number) && this.deviceId == that.deviceId; } @Override public int hashCode() { - return (int)accountId ^ (int)deviceId; + return number.hashCode() ^ (int)deviceId; } } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index e70b4a373..819808c8f 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -13,6 +13,7 @@ import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; +import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import javax.ws.rs.core.MediaType; @@ -31,6 +32,7 @@ public class AccountControllerTest { private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiter rateLimiter = mock(RateLimiter.class ); private SmsSender smsSender = mock(SmsSender.class ); + private StoredMessages storedMessages = mock(StoredMessages.class ); @Rule public final ResourceTestRule resources = ResourceTestRule.builder() @@ -38,7 +40,8 @@ public class AccountControllerTest { .addResource(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, - smsSender)) + smsSender, + storedMessages)) .build(); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java index 86a950a5a..16424c17f 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java @@ -94,7 +94,6 @@ public class WebsocketControllerTest { }}; when(device.getId()).thenReturn(2L); - when(account.getId()).thenReturn(31337L); when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(session.getRemote()).thenReturn(remote); @@ -120,7 +119,7 @@ public class WebsocketControllerTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) .thenReturn(Optional.of(account)); - when(storedMessages.getMessagesForDevice(account.getId(), device.getId())) + when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId()))) .thenReturn(outgoingMessages); WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager); @@ -128,7 +127,7 @@ public class WebsocketControllerTest { controller.onWebSocketConnect(session); - verify(pubSubManager).subscribe(eq(new WebsocketAddress(31337L, 2L)), eq((controller))); + verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller))); verify(remote, times(3)).sendStringByFuture(anyString()); controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));