diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java index ce733fbc9..8fa8fd9c9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/Experiment.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; /** * An experiment compares the results of two operations and records metrics to assess how frequently they match. @@ -40,7 +41,7 @@ public class Experiment { this.meterRegistry = meterRegistry; } - public void compareResult(final T expected, final CompletionStage experimentStage) { + public void compareFutureResult(final T expected, final CompletionStage experimentStage) { // We're assuming that we get the experiment completion stage as soon as possible after it's started, but this // start time will inescapably have some "wiggle" to it. final long start = System.nanoTime(); @@ -49,18 +50,37 @@ public class Experiment { final long duration = System.nanoTime() - start; if (cause != null) { - meterRegistry.timer(timerName, - List.of(Tag.of(OUTCOME_TAG, ERROR_OUTCOME), Tag.of(CAUSE_TAG, cause.getClass().getSimpleName()))) - .record(duration, TimeUnit.NANOSECONDS); + recordError(duration, cause); } else { - final boolean shouldIgnore = actual == null && expected != null; - - if (!shouldIgnore) { - meterRegistry.timer(timerName, - List.of(Tag.of(OUTCOME_TAG, Objects.equals(expected, actual) ? MATCH_OUTCOME : MISMATCH_OUTCOME))) - .record(duration, TimeUnit.NANOSECONDS); - } + recordResult(duration, expected, actual); } }); } + + public void compareSupplierResult(final T expected, final Supplier experimentSupplier) { + final long start = System.nanoTime(); + + try { + final T result = experimentSupplier.get(); + recordResult(System.nanoTime() - start, expected, result); + } catch (final Exception e) { + recordError(System.nanoTime() - start, e); + } + } + + private void recordError(final long durationNanos, final Throwable cause) { + meterRegistry.timer(timerName, + List.of(Tag.of(OUTCOME_TAG, ERROR_OUTCOME), Tag.of(CAUSE_TAG, cause.getClass().getSimpleName()))) + .record(durationNanos, TimeUnit.NANOSECONDS); + } + + private void recordResult(final long durationNanos, final T expected, final T actual) { + final boolean shouldIgnore = actual == null && expected != null; + + if (!shouldIgnore) { + meterRegistry.timer(timerName, + List.of(Tag.of(OUTCOME_TAG, Objects.equals(expected, actual) ? MATCH_OUTCOME : MISMATCH_OUTCOME))) + .record(durationNanos, TimeUnit.NANOSECONDS); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 366cc7e2a..3f90aa3c3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -113,7 +113,7 @@ public class RateLimiter { final String bucketName = getBucketName(key); String serialized = jedis.get(bucketName); - redisClusterExperiment.compareResult(serialized, cacheCluster.withReadCluster(connection -> connection.async().get(bucketName))); + redisClusterExperiment.compareFutureResult(serialized, cacheCluster.withReadCluster(connection -> connection.async().get(bucketName))); if (serialized != null) { return LeakyBucket.fromSerialized(mapper, serialized); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerCache.java index e71013c7e..9dc0ad57b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerCache.java @@ -60,7 +60,7 @@ public class AccountDatabaseCrawlerCache { public boolean isAccelerated() { try (Jedis jedis = jedisPool.getWriteResource()) { final String accelerated = jedis.get(ACCELERATE_KEY); - redisClusterExperiment.compareResult(accelerated, cacheCluster.withReadCluster(connection -> connection.async().get(ACCELERATE_KEY))); + redisClusterExperiment.compareFutureResult(accelerated, cacheCluster.withReadCluster(connection -> connection.async().get(ACCELERATE_KEY))); return "1".equals(accelerated); } @@ -88,7 +88,7 @@ public class AccountDatabaseCrawlerCache { public Optional getLastUuid() { try (Jedis jedis = jedisPool.getWriteResource()) { String lastUuidString = jedis.get(LAST_UUID_KEY); - redisClusterExperiment.compareResult(lastUuidString, cacheCluster.withWriteCluster(connection -> connection.async().get(LAST_UUID_KEY))); + redisClusterExperiment.compareFutureResult(lastUuidString, cacheCluster.withWriteCluster(connection -> connection.async().get(LAST_UUID_KEY))); if (lastUuidString == null) return Optional.empty(); else return Optional.of(UUID.fromString(lastUuidString)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index f14255297..4476e2d16 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -179,7 +179,7 @@ public class AccountsManager { final String key = getAccountMapKey(number); String uuid = jedis.get(key); - redisClusterExperiment.compareResult(uuid, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(uuid, cacheCluster.withReadCluster(connection -> connection.async().get(key))); if (uuid != null) return redisGet(jedis, UUID.fromString(uuid)); else return Optional.empty(); @@ -203,7 +203,7 @@ public class AccountsManager { final String key = getAccountEntityKey(uuid); String json = jedis.get(key); - redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); if (json != null) { Account account = mapper.readValue(json, Account.class); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ActiveUserCounter.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ActiveUserCounter.java index fd7753da3..00cd9f99b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ActiveUserCounter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ActiveUserCounter.java @@ -165,7 +165,7 @@ public class ActiveUserCounter extends AccountDatabaseCrawlerListener { private void incrementTallies(UUID fromUuid, Map platformIncrements, Map countryIncrements) { try (Jedis jedis = jedisPool.getWriteResource()) { String tallyValue = jedis.get(TALLY_KEY); - redisClusterExperiment.compareResult(tallyValue, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY))); + redisClusterExperiment.compareFutureResult(tallyValue, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY))); ActiveUserTally activeUserTally; @@ -210,7 +210,7 @@ public class ActiveUserCounter extends AccountDatabaseCrawlerListener { private ActiveUserTally getFinalTallies() { try (Jedis jedis = jedisPool.getReadResource()) { final String tallyJson = jedis.get(TALLY_KEY); - redisClusterExperiment.compareResult(tallyJson, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY))); + redisClusterExperiment.compareFutureResult(tallyJson, cacheCluster.withReadCluster(connection -> connection.async().get(TALLY_KEY))); return mapper.readValue(tallyJson, ActiveUserTally.class); } catch (IOException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java index baf220eee..290bd8f0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingAccountsManager.java @@ -89,7 +89,7 @@ public class PendingAccountsManager { final String key = CACHE_PREFIX + number; String json = jedis.get(key); - redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); if (json == null) return Optional.empty(); else return Optional.of(mapper.readValue(json, StoredVerificationCode.class)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java index 233cc20dd..e023a2ec7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PendingDevicesManager.java @@ -88,7 +88,7 @@ public class PendingDevicesManager { final String key = CACHE_PREFIX + number; String json = jedis.get(key); - redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().get(key))); if (json == null) return Optional.empty(); else return Optional.of(mapper.readValue(json, StoredVerificationCode.class)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java index fcdf2dea6..a30619613 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java @@ -74,7 +74,7 @@ public class ProfilesManager { final String key = CACHE_PREFIX + uuid.toString(); String json = jedis.hget(key, version); - redisClusterExperiment.compareResult(json, cacheCluster.withReadCluster(connection -> connection.async().hget(key, version))); + redisClusterExperiment.compareFutureResult(json, cacheCluster.withReadCluster(connection -> connection.async().hget(key, version))); if (json == null) return Optional.empty(); else return Optional.of(mapper.readValue(json, VersionedProfile.class)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UsernamesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UsernamesManager.java index 1fbfb97ce..78d1f320f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UsernamesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UsernamesManager.java @@ -148,7 +148,7 @@ public class UsernamesManager { final String key = getUsernameMapKey(username); String result = jedis.get(key); - redisClusterExperiment.compareResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key))); if (result == null) return Optional.empty(); else return Optional.of(UUID.fromString(result)); @@ -165,7 +165,7 @@ public class UsernamesManager { final String key = getUuidMapKey(uuid); final String result = jedis.get(key); - redisClusterExperiment.compareResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key))); + redisClusterExperiment.compareFutureResult(result, cacheCluster.withReadCluster(connection -> connection.async().get(key))); return Optional.ofNullable(result); } catch (JedisException e) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentTest.java index 1b60f2fa6..2cff0151f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentTest.java @@ -1,6 +1,5 @@ package org.whispersystems.textsecuregcm.experiment; -import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Timer; @@ -9,9 +8,7 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; -import java.util.ArrayList; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -20,7 +17,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.anyVararg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -37,11 +33,11 @@ public class ExperimentTest { } @Test - public void compareResultMatch() { + public void compareFutureResultMatch() { final Timer timer = mock(Timer.class); when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); - new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(12)); + new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(12)); @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); @@ -54,11 +50,11 @@ public class ExperimentTest { } @Test - public void compareResultMismatch() { + public void compareFutureResultMismatch() { final Timer timer = mock(Timer.class); when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); - new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(77)); + new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(77)); @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); @@ -71,11 +67,11 @@ public class ExperimentTest { } @Test - public void compareResultError() { + public void compareFutureResultError() { final Timer timer = mock(Timer.class); when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); - new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.failedFuture(new RuntimeException("OH NO"))); + new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.failedFuture(new RuntimeException("OH NO"))); @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); @@ -88,11 +84,72 @@ public class ExperimentTest { } @Test - public void compareResultNoExperimentData() { + public void compareFutureResultNoExperimentData() { final Timer timer = mock(Timer.class); when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); - new Experiment(meterRegistry, "test").compareResult(12, CompletableFuture.completedFuture(null)); + new Experiment(meterRegistry, "test").compareFutureResult(12, CompletableFuture.completedFuture(null)); + + verify(timer, never()).record(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + @Test + public void compareSupplierResultMatch() { + final Timer timer = mock(Timer.class); + when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); + + new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> 12); + + @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + + verify(meterRegistry).timer(anyString(), tagCaptor.capture()); + + final Set tags = getTagSet(tagCaptor.getValue()); + assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.MATCH_OUTCOME))); + + verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + @Test + public void compareSupplierResultMismatch() { + final Timer timer = mock(Timer.class); + when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); + + new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> 77); + + @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + + verify(meterRegistry).timer(anyString(), tagCaptor.capture()); + + final Set tags = getTagSet(tagCaptor.getValue()); + assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.MISMATCH_OUTCOME))); + + verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + @Test + public void compareSupplierResultError() { + final Timer timer = mock(Timer.class); + when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); + + new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> { throw new RuntimeException("OH NO"); }); + + @SuppressWarnings("unchecked") final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + + verify(meterRegistry).timer(anyString(), tagCaptor.capture()); + + final Set tags = getTagSet(tagCaptor.getValue()); + assertEquals(tags, Set.of(Tag.of(Experiment.OUTCOME_TAG, Experiment.ERROR_OUTCOME), Tag.of(Experiment.CAUSE_TAG, "RuntimeException"))); + + verify(timer).record(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + @Test + public void compareSupplierResultNoExperimentData() { + final Timer timer = mock(Timer.class); + when(meterRegistry.timer(anyString(), ArgumentMatchers.>any())).thenReturn(timer); + + new Experiment(meterRegistry, "test").compareSupplierResult(12, () -> null); verify(timer, never()).record(anyLong(), eq(TimeUnit.NANOSECONDS)); }