Rearrange provisioning flow. Add needsMessageSync response.

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2015-01-21 13:56:58 -08:00
parent d2dbff173a
commit f7132bdbbc
11 changed files with 131 additions and 51 deletions

View File

@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.controllers.KeepAliveController;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV1; import org.whispersystems.textsecuregcm.controllers.KeysControllerV1;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV2; import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
import org.whispersystems.textsecuregcm.controllers.ReceiptController; import org.whispersystems.textsecuregcm.controllers.ReceiptController;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.federation.FederatedPeer;
@ -182,6 +183,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1)); environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
environment.jersey().register(new FederationControllerV2(accountsManager, attachmentController, messageController, keysControllerV2)); environment.jersey().register(new FederationControllerV2(accountsManager, attachmentController, messageController, keysControllerV2));
environment.jersey().register(new ReceiptController(accountsManager, federatedClientManager, pushSender)); environment.jersey().register(new ReceiptController(accountsManager, federatedClientManager, pushSender));
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));
environment.jersey().register(attachmentController); environment.jersey().register(attachmentController);
environment.jersey().register(keysControllerV1); environment.jersey().register(keysControllerV1);
environment.jersey().register(keysControllerV2); environment.jersey().register(keysControllerV2);
@ -203,10 +205,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet ); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet );
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
websocket.addMapping("/v1/websocket/*"); websocket.addMapping("/v1/websocket/");
websocket.setAsyncSupported(true); websocket.setAsyncSupported(true);
provisioning.addMapping("/v1/provisioning/*"); provisioning.addMapping("/v1/websocket/provisioning/");
provisioning.setAsyncSupported(true); provisioning.setAsyncSupported(true);
webSocketServlet.start(); webSocketServlet.start();

View File

@ -44,6 +44,7 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom; import java.security.SecureRandom;
@ -69,7 +70,7 @@ public class DeviceController {
@Timed @Timed
@GET @GET
@Path("/provisioning_code") @Path("/provisioning/code")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Account account) public VerificationCode createDeviceToken(@Auth Account account)
throws RateLimitExceededException throws RateLimitExceededException
@ -102,7 +103,7 @@ public class DeviceController {
Optional<String> storedVerificationCode = pendingDevices.getCodeForNumber(number); Optional<String> storedVerificationCode = pendingDevices.getCodeForNumber(number);
if (!storedVerificationCode.isPresent() || if (!storedVerificationCode.isPresent() ||
!verificationCode.equals(storedVerificationCode.get())) !MessageDigest.isEqual(verificationCode.getBytes(), storedVerificationCode.get().getBytes()))
{ {
throw new WebApplicationException(Response.status(403).build()); throw new WebApplicationException(Response.status(403).build());
} }

View File

@ -0,0 +1,7 @@
package org.whispersystems.textsecuregcm.controllers;
public class InvalidDestinationException extends Exception {
public InvalidDestinationException(String message) {
super(message);
}
}

View File

@ -26,7 +26,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient; import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
@ -39,8 +39,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -85,16 +83,21 @@ public class MessageController {
@Path("/{destination}") @Path("/{destination}")
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void sendMessage(@Auth Account source, @Produces(MediaType.APPLICATION_JSON)
@PathParam("destination") String destinationName, public SendMessageResponse sendMessage(@Auth Account source,
@Valid IncomingMessageList messages) @PathParam("destination") String destinationName,
@Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException throws IOException, RateLimitExceededException
{ {
rateLimiters.getMessagesLimiter().validate(source.getNumber()); rateLimiters.getMessagesLimiter().validate(source.getNumber());
try { try {
if (messages.getRelay() == null) sendLocalMessage(source, destinationName, messages); boolean isSyncMessage = source.getNumber().equals(destinationName);
else sendRelayMessage(source, destinationName, messages);
if (messages.getRelay() == null) sendLocalMessage(source, destinationName, messages, isSyncMessage);
else sendRelayMessage(source, destinationName, messages, isSyncMessage);
return new SendMessageResponse(!isSyncMessage && source.getActiveDeviceCount() > 1);
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build()); throw new WebApplicationException(Response.status(404).build());
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
@ -108,6 +111,8 @@ public class MessageController {
.type(MediaType.APPLICATION_JSON) .type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices())) .entity(new StaleDevices(e.getStaleDevices()))
.build()); .build());
} catch (InvalidDestinationException e) {
throw new WebApplicationException(Response.status(400).build());
} }
} }
@ -131,29 +136,18 @@ public class MessageController {
} }
} }
@Timed
@PUT
@Path("/provisioning/{destination}")
@Consumes(MediaType.APPLICATION_JSON)
public void sendProvisioningMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid ProvisioningMessage message)
throws RateLimitExceededException, InvalidWebsocketAddressException, IOException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
pushSender.getWebSocketSender().sendProvisioningMessage(new ProvisioningAddress(destinationName),
Base64.decode(message.getBody()));
}
private void sendLocalMessage(Account source, private void sendLocalMessage(Account source,
String destinationName, String destinationName,
IncomingMessageList messages) IncomingMessageList messages,
boolean isSyncMessage)
throws NoSuchUserException, MismatchedDevicesException, IOException, StaleDevicesException throws NoSuchUserException, MismatchedDevicesException, IOException, StaleDevicesException
{ {
Account destination = getDestinationAccount(destinationName); Account destination;
validateCompleteDeviceList(destination, messages.getMessages()); if (!isSyncMessage) destination = getDestinationAccount(destinationName);
else destination = source;
validateCompleteDeviceList(destination, messages.getMessages(), isSyncMessage);
validateRegistrationIds(destination, messages.getMessages()); validateRegistrationIds(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) { for (IncomingMessage incomingMessage : messages.getMessages()) {
@ -201,9 +195,12 @@ public class MessageController {
private void sendRelayMessage(Account source, private void sendRelayMessage(Account source,
String destinationName, String destinationName,
IncomingMessageList messages) IncomingMessageList messages,
throws IOException, NoSuchUserException boolean isSyncMessage)
throws IOException, NoSuchUserException, InvalidDestinationException
{ {
if (isSyncMessage) throw new InvalidDestinationException("Transcript messages can't be relayed!");
try { try {
FederatedClient client = federatedClientManager.getClient(messages.getRelay()); FederatedClient client = federatedClientManager.getClient(messages.getRelay());
client.sendMessages(source.getNumber(), source.getAuthenticatedDevice().get().getId(), client.sendMessages(source.getNumber(), source.getAuthenticatedDevice().get().getId(),
@ -246,7 +243,9 @@ public class MessageController {
} }
} }
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages) private void validateCompleteDeviceList(Account account,
List<IncomingMessage> messages,
boolean isSyncMessage)
throws MismatchedDevicesException throws MismatchedDevicesException
{ {
Set<Long> messageDeviceIds = new HashSet<>(); Set<Long> messageDeviceIds = new HashSet<>();
@ -260,7 +259,9 @@ public class MessageController {
} }
for (Device device : account.getDevices()) { for (Device device : account.getDevices()) {
if (device.isActive()) { if (device.isActive() &&
!(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId()))
{
accountDeviceIds.add(device.getId()); accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) { if (!messageDeviceIds.contains(device.getId())) {

View File

@ -0,0 +1,53 @@
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import io.dropwizard.auth.Auth;
@Path("/v1/provisioning")
public class ProvisioningController {
private final RateLimiters rateLimiters;
private final WebsocketSender websocketSender;
public ProvisioningController(RateLimiters rateLimiters, PushSender pushSender) {
this.rateLimiters = rateLimiters;
this.websocketSender = pushSender.getWebSocketSender();
}
@Timed
@Path("/{destination}")
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void sendProvisioningMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid ProvisioningMessage message)
throws RateLimitExceededException, InvalidWebsocketAddressException, IOException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
if (!websocketSender.sendProvisioningMessage(new ProvisioningAddress(destinationName),
Base64.decode(message.getBody())))
{
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
}
}

View File

@ -1,18 +1,14 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
public class ProvisioningMessage { public class ProvisioningMessage {
@JsonProperty @JsonProperty
@NotEmpty
private String body; private String body;
public ProvisioningMessage() {}
public ProvisioningMessage(String body) {
this.body = body;
}
public String getBody() { public String getBody() {
return body; return body;
} }

View File

@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
public class SendMessageResponse {
@JsonProperty
private boolean needsSync;
public SendMessageResponse() {}
public SendMessageResponse(boolean needsSync) {
this.needsSync = needsSync;
}
}

View File

@ -121,6 +121,16 @@ public class Account {
return highestDevice + 1; return highestDevice + 1;
} }
public int getActiveDeviceCount() {
int count = 0;
for (Device device : devices) {
if (device.isActive()) count++;
}
return count;
}
public boolean isRateLimited() { public boolean isRateLimited() {
return true; return true;
} }

View File

@ -7,17 +7,11 @@ import java.security.SecureRandom;
public class ProvisioningAddress extends WebsocketAddress { public class ProvisioningAddress extends WebsocketAddress {
private static final String PREFIX = ">>ephemeral-";
private final String address; private final String address;
public ProvisioningAddress(String address) throws InvalidWebsocketAddressException { public ProvisioningAddress(String address) throws InvalidWebsocketAddressException {
super(address, 0); super(address, 0);
this.address = address; this.address = address;
if (address == null || !address.startsWith(PREFIX)) {
throw new InvalidWebsocketAddressException(address);
}
} }
public String getAddress() { public String getAddress() {
@ -29,8 +23,8 @@ public class ProvisioningAddress extends WebsocketAddress {
byte[] random = new byte[16]; byte[] random = new byte[16];
SecureRandom.getInstance("SHA1PRNG").nextBytes(random); SecureRandom.getInstance("SHA1PRNG").nextBytes(random);
return new ProvisioningAddress(PREFIX + Base64.encodeBytesWithoutPadding(random) return new ProvisioningAddress(Base64.encodeBytesWithoutPadding(random)
.replace('+', '-').replace('/', '_')); .replace('+', '-').replace('/', '_'));
} catch (NoSuchAlgorithmException | InvalidWebsocketAddressException e) { } catch (NoSuchAlgorithmException | InvalidWebsocketAddressException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }

View File

@ -82,7 +82,7 @@ public class DeviceControllerTest {
@Test @Test
public void validDeviceRegisterTest() throws Exception { public void validDeviceRegisterTest() throws Exception {
VerificationCode deviceCode = resources.client().resource("/v1/devices/provisioning_code") VerificationCode deviceCode = resources.client().resource("/v1/devices/provisioning/code")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(VerificationCode.class);

View File

@ -98,7 +98,7 @@ public class MessageControllerTest {
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class); .put(ClientResponse.class);
assertThat("Good Response", response.getStatus(), is(equalTo(204))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class)); verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class));
} }
@ -148,7 +148,7 @@ public class MessageControllerTest {
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class); .put(ClientResponse.class);
assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class)); verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class));
} }