Introduce `ClosableEpoch`
This commit is contained in:
parent
e0ee75e0d0
commit
bb8ce6d981
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* A closable epoch is a concurrency construct that measures the number of callers in some critical section. A closable
|
||||
* epoch can be closed to prevent new callers from entering the critical section, and takes a specific action when the
|
||||
* critical section is empty after closure.
|
||||
*/
|
||||
public class ClosableEpoch {
|
||||
|
||||
private final Runnable onCloseHandler;
|
||||
|
||||
private final AtomicInteger state = new AtomicInteger();
|
||||
|
||||
private static final int CLOSING_BIT_MASK = 0x00000001;
|
||||
|
||||
/**
|
||||
* Constructs a new closable epoch that will execute the given handler when the epoch is closed and all callers have
|
||||
* departed the critical section. The handler will be executed on the thread that calls {@link #close()} if the
|
||||
* critical section is empty at the time of the call or on the last thread to call {@link #depart()} otherwise.
|
||||
* Callers should provide handlers that delegate execution to a specific thread/executor if more precise control over
|
||||
* which thread runs the handler is required.
|
||||
*
|
||||
* @param onCloseHandler a handler to run when the epoch is closed and all callers have departed the critical section
|
||||
*/
|
||||
public ClosableEpoch(final Runnable onCloseHandler) {
|
||||
this.onCloseHandler = onCloseHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Announces the arrival of a caller at the start of the critical section. If the caller is allowed to enter the
|
||||
* critical section, the epoch's internal caller counter is incremented accordingly.
|
||||
*
|
||||
* @return {@code true} if the caller is allowed to enter the critical section or {@code false} if it is not allowed
|
||||
* to enter the critical section because this epoch is closing
|
||||
*/
|
||||
public boolean tryArrive() {
|
||||
// Increment the number of active callers if and only if we're not closing. We add 2 because the lowest bit encodes
|
||||
// the "closing" state, and the bits above it encode the actual call count. More verbosely, we're doing
|
||||
// `state += (1 << 1)` to avoid overwriting the closing state bit.
|
||||
return !isClosing(state.updateAndGet(state -> isClosing(state) ? state : state + 2));
|
||||
}
|
||||
|
||||
/**
|
||||
* Announces the departure of a caller from the critical section. If the epoch is closing and the caller is the last
|
||||
* to depart the critical section, then the epoch will fire its {@code onCloseHandler}.
|
||||
*/
|
||||
public void depart() {
|
||||
// Decrement the active caller count unconditionally. As with `tryActive`, we work in increments of 2 to "dodge" the
|
||||
// "is closing" bit. If the call count is zero and we're closing then `state` will just have the "closing" bit set.
|
||||
if (state.addAndGet(-2) == CLOSING_BIT_MASK) {
|
||||
onCloseHandler.run();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes this epoch, preventing new callers from entering the critical section. If the critical section is empty when
|
||||
* this method is called, it will trigger the {@code onCloseHandler} immediately. Otherwise, the
|
||||
* {@code onCloseHandler} will fire when the last caller departs the critical section.
|
||||
*
|
||||
* @throws IllegalStateException if this epoch is already closed; note that this exception is thrown on a
|
||||
* "best-effort" basis to help callers detect bugs
|
||||
*/
|
||||
public void close() {
|
||||
// Note that this is not airtight and is a "best-effort" check
|
||||
if (isClosing(state.get())) {
|
||||
throw new IllegalStateException("Epoch already closed");
|
||||
}
|
||||
|
||||
// Set the "closing" bit. If the closing bit is the only bit set, then the call count is zero and we can call the
|
||||
// "on close" handler.
|
||||
if (state.updateAndGet(state -> state | CLOSING_BIT_MASK) == CLOSING_BIT_MASK) {
|
||||
onCloseHandler.run();
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
int getActiveCallers() {
|
||||
return state.get() >> 1;
|
||||
}
|
||||
|
||||
private static boolean isClosing(final int state) {
|
||||
return (state & CLOSING_BIT_MASK) != 0;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
|
||||
import java.util.concurrent.BrokenBarrierException;
|
||||
import java.util.concurrent.CyclicBarrier;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class ClosableEpochTest {
|
||||
|
||||
@Test
|
||||
void close() {
|
||||
{
|
||||
final AtomicBoolean closed = new AtomicBoolean(false);
|
||||
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
|
||||
|
||||
assertTrue(closableEpoch.tryArrive(), "New callers should be allowed to arrive before closure");
|
||||
assertEquals(1, closableEpoch.getActiveCallers());
|
||||
|
||||
closableEpoch.close();
|
||||
assertFalse(closableEpoch.tryArrive(), "New callers should not be allowed to arrive after closure");
|
||||
assertEquals(1, closableEpoch.getActiveCallers());
|
||||
assertFalse(closed.get(), "Close handler should not fire until all callers have departed");
|
||||
|
||||
closableEpoch.depart();
|
||||
assertTrue(closed.get(), "Close handler should fire after last caller departs");
|
||||
assertEquals(0, closableEpoch.getActiveCallers());
|
||||
|
||||
assertThrows(IllegalStateException.class, closableEpoch::close,
|
||||
"Double-closing a epoch should throw an exception");
|
||||
}
|
||||
|
||||
{
|
||||
final AtomicBoolean closed = new AtomicBoolean(false);
|
||||
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> closed.set(true));
|
||||
|
||||
closableEpoch.close();
|
||||
assertTrue(closed.get(), "Empty epoch should fire close handler immediately on closure");
|
||||
assertEquals(0, closableEpoch.getActiveCallers());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
|
||||
void closeConcurrent() throws InterruptedException {
|
||||
final AtomicBoolean closed = new AtomicBoolean(false);
|
||||
final ClosableEpoch closableEpoch = new ClosableEpoch(() -> {
|
||||
synchronized (closed) {
|
||||
closed.set(true);
|
||||
closed.notifyAll();
|
||||
}
|
||||
});
|
||||
|
||||
final int threadCount = 128;
|
||||
final CyclicBarrier cyclicBarrier = new CyclicBarrier(threadCount);
|
||||
|
||||
// Spawn a bunch of threads doing some simulated work. Close the epoch roughly halfway through. Some threads should
|
||||
// successfully enter the critical section and others should be rejected.
|
||||
for (int t = 0; t < threadCount; t++) {
|
||||
final boolean shouldClose = t == threadCount / 2;
|
||||
|
||||
Thread.ofVirtual().start(() -> {
|
||||
try {
|
||||
// Wait for all threads to reach the proverbial starting line
|
||||
cyclicBarrier.await();
|
||||
} catch (final InterruptedException | BrokenBarrierException ignored) {
|
||||
}
|
||||
|
||||
if (shouldClose) {
|
||||
closableEpoch.close();
|
||||
}
|
||||
|
||||
if (closableEpoch.tryArrive()) {
|
||||
// Perform some simulated "work"
|
||||
try {
|
||||
Thread.sleep(1);
|
||||
} catch (final InterruptedException ignored) {
|
||||
} finally {
|
||||
closableEpoch.depart();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
while (!closed.get()) {
|
||||
synchronized (closed) {
|
||||
closed.wait();
|
||||
}
|
||||
}
|
||||
|
||||
assertEquals(0, closableEpoch.getActiveCallers());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue