diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ClosableEpoch.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ClosableEpoch.java new file mode 100644 index 000000000..71308c9ba --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ClosableEpoch.java @@ -0,0 +1,92 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.google.common.annotations.VisibleForTesting; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable + * epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the + * critical section is empty after closure. + */ +public class ClosableEpoch { + + private final Runnable onCloseHandler; + + private final AtomicInteger state = new AtomicInteger(); + + private static final int CLOSING_BIT_MASK = 0x00000001; + + /** + * Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have + * departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the + * critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise. + * Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over + * which thread runs the handler is required. + * + * @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section + */ + public ClosableEpoch(final Runnable onCloseHandler) { + this.onCloseHandler = onCloseHandler; + } + + /** + * Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the + * critical section, the epoch's internal caller counter is incremented accordingly. + * + * @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed + * to enter the critical section because this epoch is closing + */ + public boolean tryArrive() { + // Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes + // the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing + // `state += (1 << 1)` to avoid overwriting the closing state bit. + return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2)); + } + + /** + * Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last + * to depart the critical section, then the epoch will fire its {@code onCloseHandler}. + */ + public void depart() { + // Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the + // "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set. + if (state.addAndGet(-2) == CLOSING_BIT_MASK) { + onCloseHandler.run(); + } + } + + /** + * Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when + * this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the + * {@code onCloseHandler} will fire when the last caller departs the critical section. + * + * @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a + * "best-effort" basis to help callers detect bugs + */ + public void close() { + // Note that this is not airtight and is a "best-effort" check + if (isClosing(state.get())) { + throw new IllegalStateException("Epoch already closed"); + } + + // Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the + // "on close" handler. + if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) { + onCloseHandler.run(); + } + } + + @VisibleForTesting + int getActiveCallers() { + return state.get() >> 1; + } + + private static boolean isClosing(final int state) { + return (state & CLOSING_BIT_MASK) != 0; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/ClosableEpochTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/ClosableEpochTest.java new file mode 100644 index 000000000..cf0b2fdc6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/ClosableEpochTest.java @@ -0,0 +1,101 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.jupiter.api.Assertions.*; + +class ClosableEpochTest { + + @Test + void close() { + { + final AtomicBoolean closed = new AtomicBoolean(false); + final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true)); + + assertTrue(closableEpoch.tryArrive(), "New callers should be allowed to arrive before closure"); + assertEquals(1, closableEpoch.getActiveCallers()); + + closableEpoch.close(); + assertFalse(closableEpoch.tryArrive(), "New callers should not be allowed to arrive after closure"); + assertEquals(1, closableEpoch.getActiveCallers()); + assertFalse(closed.get(), "Close handler should not fire until all callers have departed"); + + closableEpoch.depart(); + assertTrue(closed.get(), "Close handler should fire after last caller departs"); + assertEquals(0, closableEpoch.getActiveCallers()); + + assertThrows(IllegalStateException.class, closableEpoch::close, + "Double-closing a epoch should throw an exception"); + } + + { + final AtomicBoolean closed = new AtomicBoolean(false); + final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true)); + + closableEpoch.close(); + assertTrue(closed.get(), "Empty epoch should fire close handler immediately on closure"); + assertEquals(0, closableEpoch.getActiveCallers()); + } + } + + @Test + @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) + void closeConcurrent() throws InterruptedException { + final AtomicBoolean closed = new AtomicBoolean(false); + final ClosableEpoch closableEpoch = new ClosableEpoch(() -> { + synchronized (closed) { + closed.set(true); + closed.notifyAll(); + } + }); + + final int threadCount = 128; + final CyclicBarrier cyclicBarrier = new CyclicBarrier(threadCount); + + // Spawn a bunch of threads doing some simulated work. Close the epoch roughly halfway through. Some threads should + // successfully enter the critical section and others should be rejected. + for (int t = 0; t < threadCount; t++) { + final boolean shouldClose = t == threadCount / 2; + + Thread.ofVirtual().start(() -> { + try { + // Wait for all threads to reach the proverbial starting line + cyclicBarrier.await(); + } catch (final InterruptedException | BrokenBarrierException ignored) { + } + + if (shouldClose) { + closableEpoch.close(); + } + + if (closableEpoch.tryArrive()) { + // Perform some simulated "work" + try { + Thread.sleep(1); + } catch (final InterruptedException ignored) { + } finally { + closableEpoch.depart(); + } + } + }); + } + + while (!closed.get()) { + synchronized (closed) { + closed.wait(); + } + } + + assertEquals(0, closableEpoch.getActiveCallers()); + } +}