diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java new file mode 100644 index 000000000..f981f094d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcService.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import org.signal.chat.calling.GetTurnCredentialsRequest; +import org.signal.chat.calling.GetTurnCredentialsResponse; +import org.signal.chat.calling.ReactorCallingGrpc; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import reactor.core.publisher.Mono; + +public class CallingGrpcService extends ReactorCallingGrpc.CallingImplBase { + + private final TurnTokenGenerator turnTokenGenerator; + private final RateLimiters rateLimiters; + + public CallingGrpcService(final TurnTokenGenerator turnTokenGenerator, final RateLimiters rateLimiters) { + this.turnTokenGenerator = turnTokenGenerator; + this.rateLimiters = rateLimiters; + } + + @Override + protected Throwable onErrorMap(final Throwable throwable) { + return RateLimitUtil.mapRateLimitExceededException(throwable); + } + + @Override + public Mono getTurnCredentials(final GetTurnCredentialsRequest request) { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + + return rateLimiters.getTurnLimiter().validateReactive(authenticatedDevice.accountIdentifier()) + .then(Mono.fromSupplier(() -> turnTokenGenerator.generate(authenticatedDevice.accountIdentifier()))) + .map(turnToken -> GetTurnCredentialsResponse.newBuilder() + .setUsername(turnToken.username()) + .setPassword(turnToken.password()) + .addAllUrls(turnToken.urls()) + .build()); + } +} diff --git a/service/src/main/proto/org/signal/chat/calling.proto b/service/src/main/proto/org/signal/chat/calling.proto new file mode 100644 index 000000000..0330dc4b9 --- /dev/null +++ b/service/src/main/proto/org/signal/chat/calling.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +option java_multiple_files = true; + +package org.signal.chat.calling; + +/** + * Provides methods for getting credentials for one-on-one and group calls. + */ +service Calling { + + /** + * Generates and returns TURN credentials for the caller. + * + * This RPC may fail with a `RESOURCE_EXHAUSTED` status if a rate limit for + * generating TURN credentials has been exceeded, in which case a + * `retry-after` header containing an ISO 8601 duration string will be present + * in the response trailers. + */ + rpc GetTurnCredentials(GetTurnCredentialsRequest) returns (GetTurnCredentialsResponse) {} +} + +message GetTurnCredentialsRequest {} + +message GetTurnCredentialsResponse { + /** + * A username that can be presented to authenticate with a TURN server. + */ + string username = 1; + + /** + * A password that can be presented to authenticate with a TURN server. + */ + string password = 2; + + /** + * A list of TURN (or TURNS or STUN) servers where the provided credentials + * may be used. + */ + repeated string urls = 3; +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java new file mode 100644 index 000000000..122448fd6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallingGrpcServiceTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.time.Duration; +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.chat.calling.CallingGrpc; +import org.signal.chat.calling.GetTurnCredentialsRequest; +import org.signal.chat.calling.GetTurnCredentialsResponse; +import org.whispersystems.textsecuregcm.auth.TurnToken; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Mono; + +class CallingGrpcServiceTest { + + private TurnTokenGenerator turnTokenGenerator; + private RateLimiter turnCredentialRateLimiter; + + private CallingGrpc.CallingBlockingStub callingStub; + + private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); + private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; + + @RegisterExtension + static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); + + @BeforeEach + void setUp() { + turnTokenGenerator = mock(TurnTokenGenerator.class); + turnCredentialRateLimiter = mock(RateLimiter.class); + + final RateLimiters rateLimiters = mock(RateLimiters.class); + when(rateLimiters.getTurnLimiter()).thenReturn(turnCredentialRateLimiter); + + final CallingGrpcService callingGrpcService = new CallingGrpcService(turnTokenGenerator, rateLimiters); + + final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); + mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); + + GRPC_SERVER_EXTENSION.getServiceRegistry() + .addService(ServerInterceptors.intercept(callingGrpcService, mockAuthenticationInterceptor)); + + callingStub = CallingGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); + } + + @Test + void getTurnCredentials() { + final String username = "test-username"; + final String password = "test-password"; + final List urls = List.of("first", "second"); + + when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)).thenReturn(Mono.empty()); + when(turnTokenGenerator.generate(any())).thenReturn(new TurnToken(username, password, urls)); + + final GetTurnCredentialsResponse response = callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build()); + + final GetTurnCredentialsResponse expectedResponse = GetTurnCredentialsResponse.newBuilder() + .setUsername(username) + .setPassword(password) + .addAllUrls(urls) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void getTurnCredentialsRateLimited() { + final Duration retryAfter = Duration.ofMinutes(19); + + when(turnCredentialRateLimiter.validateReactive(AUTHENTICATED_ACI)) + .thenReturn(Mono.error(new RateLimitExceededException(retryAfter, false))); + + final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> callingStub.getTurnCredentials(GetTurnCredentialsRequest.newBuilder().build())); + + verify(turnTokenGenerator, never()).generate(any()); + + assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); + assertNotNull(exception.getTrailers()); + assertEquals(retryAfter, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); + } +}