Introduce experiment comparison methods for suppliers.

This commit is contained in:
Jon Chambers 2020-06-09 17:14:46 -04:00 committed by Jon Chambers
parent 0713da7393
commit 0671f05c05
10 changed files with 112 additions and 35 deletions

View File

@ -10,6 +10,7 @@ import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
/**
* An experiment compares the results of two operations and records metrics to assess how frequently they match.
@ -40,7 +41,7 @@ public class Experiment {
this.meterRegistry = meterRegistry;
}
public <T> void compareResult(final T expected, final CompletionStage<T> experimentStage) {
public <T> void compareFutureResult(final T expected, final CompletionStage<T> experimentStage) {
// We're assuming that we get the experiment completion stage as soon as possible after it's started, but this
// start time will inescapably have some "wiggle" to it.
final long start = System.nanoTime();
@ -49,18 +50,37 @@ public class Experiment {
final long duration = System.nanoTime() - start;
if (cause != null) {
meterRegistry.timer(timerName,
List.of(Tag.of(OUTCOME_TAG, ERROR_OUTCOME), Tag.of(CAUSE_TAG, cause.getClass().getSimpleName())))
.record(duration, TimeUnit.NANOSECONDS);
recordError(duration, cause);
} else {
final boolean shouldIgnore = actual == null && expected != null;
if (!shouldIgnore) {
meterRegistry.timer(timerName,
List.of(Tag.of(OUTCOME_TAG, Objects.equals(expected, actual) ? MATCH_OUTCOME : MISMATCH_OUTCOME)))
.record(duration, TimeUnit.NANOSECONDS);
}
recordResult(duration, expected, actual);
}
});
}
public <T> void compareSupplierResult(final T expected, final Supplier<T> experimentSupplier) {
final long start = System.nanoTime();
try {
final T result = experimentSupplier.get();
recordResult(System.nanoTime() - start, expected, result);
} catch (final Exception e) {
recordError(System.nanoTime() - start, e);
}
}
private void recordError(final long durationNanos, final Throwable cause) {
meterRegistry.timer(timerName,
List.of(Tag.of(OUTCOME_TAG, ERROR_OUTCOME), Tag.of(CAUSE_TAG, cause.getClass().getSimpleName())))
.record(durationNanos, TimeUnit.NANOSECONDS);
}
private <T> void recordResult(final long durationNanos, final T expected, final T actual) {
final boolean shouldIgnore = actual == null && expected != null;
if (!shouldIgnore) {
meterRegistry.timer(timerName,
List.of(Tag.of(OUTCOME_TAG, Objects.equals(expected, actual) ? MATCH_OUTCOME : MISMATCH_OUTCOME)))
.record(durationNanos, TimeUnit.NANOSECONDS);
}
}
}

View File

@ -113,7 +113,7 @@ public class RateLimiter {
final String bucketName = getBucketName(key);
String serialized = jedis.get(bucketName);
redisClusterExperiment.compareResult(serialized, cacheCluster.withReadCluster(connection -> connection.async().get(bucketName)));
redisClusterExperiment.compareFutureResult(serialized, cacheCluster.withReadCluster(connection -> connection.async().get(bucketName)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);

View File

@ -60,7 +60,7 @@ public class AccountDatabaseCrawlerCache {
public boolean isAccelerated() {
try (Jedis jedis = jedisPool.getWriteResource()) {
final String accelerated = jedis.get(ACCELERATE_KEY);
redisClusterExperiment.compareResult(accelerated, cacheCluster.withReadCluster(connection -> connection.async().get(ACCELERATE_KEY)));
redisClusterExperiment.compareFutureResult(accelerated, cacheCluster.withReadCluster(connection -> connection.async().get(ACCELERATE_KEY)));
return "1".equals(accelerated);
}
@ -88,7 +88,7 @@ public class AccountDatabaseCrawlerCache {
public Optional<UUID> getLastUuid() {
try (Jedis jedis = jedisPool.getWriteResource()) {
String lastUuidString = jedis.get(LAST_UUID_KEY);
redisClusterExperiment.compareResult(lastUuidString, cacheCluster.withWriteCluster(connection -> connection.async().get(LAST_UUID_KEY)));
redisClusterExperiment.compareFutureResult(lastUuidString, cacheCluster.withWriteCluster(connection -> connection.async().get(LAST_UUID_KEY)));
if (lastUuidString == null) return Optional.empty();
else return Optional.of(UUID.fromString(lastUuidString));

View File

@ -179,7 +179,7 @@ public class AccountsManager {
final String key = getAccountMapKey(number);
String uuid = jedis.get(key);
redisClusterExperiment.compareResult(uuid, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(uuid, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
if (uuid != null) return redisGet(jedis, UUID.fromString(uuid));
else return Optional.empty();
@ -203,7 +203,7 @@ public class AccountsManager {
final String key = getAccountEntityKey(uuid);
String json = jedis.get(key);
redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
if (json != null) {
Account account = mapper.readValue(json, Account.class);

View File

@ -165,7 +165,7 @@ public class ActiveUserCounter extends AccountDatabaseCrawlerListener {
private void incrementTallies(UUID fromUuid, Map<String, long[]> platformIncrements, Map<String, long[]> countryIncrements) {
try (Jedis jedis = jedisPool.getWriteResource()) {
String tallyValue = jedis.get(TALLY_KEY);
redisClusterExperiment.compareResult(tallyValue, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY)));
redisClusterExperiment.compareFutureResult(tallyValue, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY)));
ActiveUserTally activeUserTally;
@ -210,7 +210,7 @@ public class ActiveUserCounter extends AccountDatabaseCrawlerListener {
private ActiveUserTally getFinalTallies() {
try (Jedis jedis = jedisPool.getReadResource()) {
final String tallyJson = jedis.get(TALLY_KEY);
redisClusterExperiment.compareResult(tallyJson, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY)));
redisClusterExperiment.compareFutureResult(tallyJson, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY)));
return mapper.readValue(tallyJson, ActiveUserTally.class);
} catch (IOException e) {

View File

@ -89,7 +89,7 @@ public class PendingAccountsManager {
final String key = CACHE_PREFIX + number;
String json = jedis.get(key);
redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
if (json == null) return Optional.empty();
else return Optional.of(mapper.readValue(json, StoredVerificationCode.class));

View File

@ -88,7 +88,7 @@ public class PendingDevicesManager {
final String key = CACHE_PREFIX + number;
String json = jedis.get(key);
redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
if (json == null) return Optional.empty();
else return Optional.of(mapper.readValue(json, StoredVerificationCode.class));

View File

@ -74,7 +74,7 @@ public class ProfilesManager {
final String key = CACHE_PREFIX + uuid.toString();
String json = jedis.hget(key, version);
redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().hget(key, version)));
redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().hget(key, version)));
if (json == null) return Optional.empty();
else return Optional.of(mapper.readValue(json, VersionedProfile.class));

View File

@ -148,7 +148,7 @@ public class UsernamesManager {
final String key = getUsernameMapKey(username);
String result = jedis.get(key);
redisClusterExperiment.compareResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
if (result == null) return Optional.empty();
else return Optional.of(UUID.fromString(result));
@ -165,7 +165,7 @@ public class UsernamesManager {
final String key = getUuidMapKey(uuid);
final String result = jedis.get(key);
redisClusterExperiment.compareResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
redisClusterExperiment.compareFutureResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key)));
return Optional.ofNullable(result);
} catch (JedisException e) {

View File

@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.experiment;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Timer;
@ -9,9 +8,7 @@ import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
@ -20,7 +17,6 @@ import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.anyVararg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -37,11 +33,11 @@ public class ExperimentTest {
}
@Test
public void compareResultMatch() {
public void compareFutureResultMatch() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(12));
new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(12));
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
@ -54,11 +50,11 @@ public class ExperimentTest {
}
@Test
public void compareResultMismatch() {
public void compareFutureResultMismatch() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(77));
new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(77));
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
@ -71,11 +67,11 @@ public class ExperimentTest {
}
@Test
public void compareResultError() {
public void compareFutureResultError() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.failedFuture(new RuntimeException("OH NO")));
new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.failedFuture(new RuntimeException("OH NO")));
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
@ -88,11 +84,72 @@ public class ExperimentTest {
}
@Test
public void compareResultNoExperimentData() {
public void compareFutureResultNoExperimentData() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(null));
new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(null));
verify(timer, never()).record(anyLong(), eq(TimeUnit.NANOSECONDS));
}
@Test
public void compareSupplierResultMatch() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> 12);
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).timer(anyString(), tagCaptor.capture());
final Set<Tag> tags = getTagSet(tagCaptor.getValue());
assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.MATCH_OUTCOME)));
verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS));
}
@Test
public void compareSupplierResultMismatch() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> 77);
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).timer(anyString(), tagCaptor.capture());
final Set<Tag> tags = getTagSet(tagCaptor.getValue());
assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.MISMATCH_OUTCOME)));
verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS));
}
@Test
public void compareSupplierResultError() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> { throw new RuntimeException("OH NO"); });
@SuppressWarnings("unchecked") final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).timer(anyString(), tagCaptor.capture());
final Set<Tag> tags = getTagSet(tagCaptor.getValue());
assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.ERROR_OUTCOME), Tag.of(Experiment.CAUSE_TAG, "RuntimeException")));
verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS));
}
@Test
public void compareSupplierResultNoExperimentData() {
final Timer timer = mock(Timer.class);
when(meterRegistry.timer(anyString(), ArgumentMatchers.<Iterable<Tag>>any())).thenReturn(timer);
new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> null);
verify(timer, never()).record(anyLong(), eq(TimeUnit.NANOSECONDS));
}