Auto serializable transaction retry

This commit is contained in:
Moxie Marlinspike 2019-03-27 21:24:49 -07:00
parent 890b0ac301
commit c75dada340
3 changed files with 36 additions and 22 deletions

View File

@ -17,7 +17,6 @@
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.Anonymous;
@ -195,15 +194,7 @@ public class KeysController {
long deviceId = Long.parseLong(deviceIdSelector);
for (int i=0;i<20;i++) {
try {
return keys.get(destination.getNumber(), deviceId);
} catch (UnableToExecuteStatementException e) {
logger.info(e.getMessage());
}
}
throw new WebApplicationException(Response.status(500).build());
return keys.get(destination.getNumber(), deviceId);
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}

View File

@ -19,6 +19,7 @@ package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel;
import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.Binder;
import org.skife.jdbi.v2.sqlobject.BinderFactory;
@ -29,6 +30,8 @@ import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PreKey;
import java.lang.annotation.Annotation;
@ -39,10 +42,13 @@ import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
public abstract class Keys {
private static final Logger logger = LoggerFactory.getLogger(Keys.class);
@SqlUpdate("DELETE FROM keys WHERE number = :number AND device_id = :device_id")
abstract void remove(@Bind("number") String number, @Bind("device_id") long deviceId);
@ -61,16 +67,27 @@ public abstract class Keys {
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
// Apparently transaction annotations don't work on the annotated query methods
@SuppressWarnings("WeakerAccess")
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number) {
List<KeyRecord> getInternalWithTransaction(String number) {
return getInternal(number);
}
// Apparently transaction annotations don't work on the annotated query methods
@SuppressWarnings("WeakerAccess")
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public List<KeyRecord> get(String number, long deviceId) {
List<KeyRecord> getInternalWithTransaction(String number, long deviceId) {
return getInternal(number, deviceId);
}
public List<KeyRecord> get(String number) {
return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number));
}
public List<KeyRecord> get(String number, long deviceId) {
return executeAndRetrySerializableAction(() -> getInternalWithTransaction(number, deviceId));
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List<PreKey> keys) {
List<KeyRecord> records = keys.stream()
@ -81,6 +98,20 @@ public abstract class Keys {
append(records);
}
private List<KeyRecord> executeAndRetrySerializableAction(Callable<List<KeyRecord>> action) {
for (int i=0;i<20;i++) {
try {
return action.call();
} catch (UnableToExecuteStatementException e) {
logger.info("Serializable conflict, retrying: " + e.getMessage());
} catch (Exception e) {
throw new AssertionError(e);
}
}
throw new UnableToExecuteStatementException("Retried statement too many times!");
}
@SqlUpdate("VACUUM keys")
public abstract void vacuum();

View File

@ -241,16 +241,8 @@ public class KeysTest {
for (int i=0;i<50;i++) {
Thread thread = new Thread(() -> {
for (int j=0;j<10;j++) {
try {
List<KeyRecord> results = keys.get("+14152222222");
assertThat(results.size()).isEqualTo(2);
return;
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
throw new AssertionError();
List<KeyRecord> results = keys.get("+14152222222");
assertThat(results.size()).isEqualTo(2);
});
thread.start();
threads.add(thread);