Introduce `ClosableEpoch`

This commit is contained in:
Jon Chambers 2025-04-21 21:31:08 -04:00 committed by Jon Chambers
parent e0ee75e0d0
commit bb8ce6d981
2 changed files with 193 additions and 0 deletions

View File

@ -0,0 +1,92 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable
* epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the
* critical section is empty after closure.
*/
public class ClosableEpoch {
private final Runnable onCloseHandler;
private final AtomicInteger state = new AtomicInteger();
private static final int CLOSING_BIT_MASK = 0x00000001;
/**
* Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have
* departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the
* critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise.
* Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over
* which thread runs the handler is required.
*
* @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section
*/
public ClosableEpoch(final Runnable onCloseHandler) {
this.onCloseHandler = onCloseHandler;
}
/**
* Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the
* critical section, the epoch's internal caller counter is incremented accordingly.
*
* @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed
* to enter the critical section because this epoch is closing
*/
public boolean tryArrive() {
// Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes
// the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing
// `state += (1 << 1)` to avoid overwriting the closing state bit.
return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2));
}
/**
* Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last
* to depart the critical section, then the epoch will fire its {@code onCloseHandler}.
*/
public void depart() {
// Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the
// "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set.
if (state.addAndGet(-2) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
/**
* Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when
* this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the
* {@code onCloseHandler} will fire when the last caller departs the critical section.
*
* @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a
* "best-effort" basis to help callers detect bugs
*/
public void close() {
// Note that this is not airtight and is a "best-effort" check
if (isClosing(state.get())) {
throw new IllegalStateException("Epoch already closed");
}
// Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the
// "on close" handler.
if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) {
onCloseHandler.run();
}
}
@VisibleForTesting
int getActiveCallers() {
return state.get() >> 1;
}
private static boolean isClosing(final int state) {
return (state & CLOSING_BIT_MASK) != 0;
}
}

View File

@ -0,0 +1,101 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.*;
class ClosableEpochTest {
@Test
void close() {
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
assertTrue(closableEpoch.tryArrive(), "New callers should be allowed to arrive before closure");
assertEquals(1, closableEpoch.getActiveCallers());
closableEpoch.close();
assertFalse(closableEpoch.tryArrive(), "New callers should not be allowed to arrive after closure");
assertEquals(1, closableEpoch.getActiveCallers());
assertFalse(closed.get(), "Close handler should not fire until all callers have departed");
closableEpoch.depart();
assertTrue(closed.get(), "Close handler should fire after last caller departs");
assertEquals(0, closableEpoch.getActiveCallers());
assertThrows(IllegalStateException.class, closableEpoch::close,
"Double-closing a epoch should throw an exception");
}
{
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
closableEpoch.close();
assertTrue(closed.get(), "Empty epoch should fire close handler immediately on closure");
assertEquals(0, closableEpoch.getActiveCallers());
}
}
@Test
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
void closeConcurrent() throws InterruptedException {
final AtomicBoolean closed = new AtomicBoolean(false);
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> {
synchronized (closed) {
closed.set(true);
closed.notifyAll();
}
});
final int threadCount = 128;
final CyclicBarrier cyclicBarrier = new CyclicBarrier(threadCount);
// Spawn a bunch of threads doing some simulated work. Close the epoch roughly halfway through. Some threads should
// successfully enter the critical section and others should be rejected.
for (int t = 0; t < threadCount; t++) {
final boolean shouldClose = t == threadCount / 2;
Thread.ofVirtual().start(() -> {
try {
// Wait for all threads to reach the proverbial starting line
cyclicBarrier.await();
} catch (final InterruptedException | BrokenBarrierException ignored) {
}
if (shouldClose) {
closableEpoch.close();
}
if (closableEpoch.tryArrive()) {
// Perform some simulated "work"
try {
Thread.sleep(1);
} catch (final InterruptedException ignored) {
} finally {
closableEpoch.depart();
}
}
});
}
while (!closed.get()) {
synchronized (closed) {
closed.wait();
}
}
assertEquals(0, closableEpoch.getActiveCallers());
}
}