Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

THRIFT-5774: Add remote client's IP address to ServerContext in TServ… #2959

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 16 additions & 3 deletions lib/java/src/crossTest/java/org/apache/thrift/test/TestServer.java
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.test;

import java.net.SocketAddress;
import org.apache.thrift.TMultiplexedProcessor;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
Expand Down Expand Up @@ -69,8 +70,11 @@ static class TestServerContext implements ServerContext {

int connectionId;

public TestServerContext(int connectionId) {
SocketAddress remoteAddress;

public TestServerContext(int connectionId, SocketAddress remoteAddress) {
this.connectionId = connectionId;
this.remoteAddress = remoteAddress;
}

public int getConnectionId() {
Expand All @@ -81,6 +85,14 @@ public void setConnectionId(int connectionId) {
this.connectionId = connectionId;
}

public SocketAddress getRemoteAddress() {
return remoteAddress;
}

public void setRemoteAddress(SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}

@Override
public <T> T unwrap(Class<T> iface) {
try {
Expand Down Expand Up @@ -110,9 +122,10 @@ public void preServe() {
"TServerEventHandler.preServe - called only once before server starts accepting connections");
}

public ServerContext createContext(TProtocol input, TProtocol output) {
public ServerContext createContext(
TProtocol input, TProtocol output, SocketAddress remoteAddress) {
// we can create some connection level data which is stored while connection is alive & served
TestServerContext ctx = new TestServerContext(nextConnectionId++);
TestServerContext ctx = new TestServerContext(nextConnectionId++, remoteAddress);
System.out.println(
"TServerEventHandler.createContext - connection #"
+ ctx.getConnectionId()
Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.apache.thrift.server;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
Expand All @@ -34,6 +35,7 @@
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TNonblockingServerTransport;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
Expand Down Expand Up @@ -296,7 +298,9 @@ public FrameBuffer(
outProt_ = outputProtocolFactory_.getProtocol(outTrans_);

if (eventHandler_ != null) {
context_ = eventHandler_.createContext(inProt_, outProt_);
TNonblockingSocket socket = (TNonblockingSocket) trans_;
SocketAddress remoteAddress = socket.getSocketChannel().socket().getRemoteSocketAddress();
context_ = eventHandler_.createContext(inProt_, outProt_, remoteAddress);
} else {
context_ = null;
}
Expand Down
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;

Expand All @@ -38,7 +39,7 @@ public interface TServerEventHandler {
void preServe();

/** Called when a new client has connected and is about to being processing. */
ServerContext createContext(TProtocol input, TProtocol output);
ServerContext createContext(TProtocol input, TProtocol output, SocketAddress remoteAddress);

/** Called when a client has finished request-handling to delete server context. */
void deleteContext(ServerContext serverContext, TProtocol input, TProtocol output);
Expand Down
Expand Up @@ -19,9 +19,11 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
Expand Down Expand Up @@ -69,7 +71,10 @@ public void serve() {
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
if (eventHandler_ != null) {
connectionContext = eventHandler_.createContext(inputProtocol, outputProtocol);
TSocket socket = (TSocket) client;
SocketAddress remoteAddress = socket.getSocket().getRemoteSocketAddress();
connectionContext =
eventHandler_.createContext(inputProtocol, outputProtocol, remoteAddress);
}
while (true) {
if (eventHandler_ != null) {
Expand Down
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import java.net.SocketException;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
Expand All @@ -32,6 +33,7 @@
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TServerTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
Expand Down Expand Up @@ -239,7 +241,10 @@ public void run() {
eventHandler = Optional.ofNullable(getEventHandler());

if (eventHandler.isPresent()) {
connectionContext = eventHandler.get().createContext(inputProtocol, outputProtocol);
TSocket socket = (TSocket) client_;
SocketAddress remoteAddress = socket.getSocket().getRemoteSocketAddress();
connectionContext =
eventHandler.get().createContext(inputProtocol, outputProtocol, remoteAddress);
}

while (true) {
Expand Down
Expand Up @@ -22,6 +22,7 @@
import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;

import java.net.SocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;
import javax.security.sasl.SaslServer;
Expand All @@ -32,6 +33,7 @@
import org.apache.thrift.server.ServerContext;
import org.apache.thrift.server.TServerEventHandler;
import org.apache.thrift.transport.TMemoryTransport;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType;
Expand Down Expand Up @@ -324,7 +326,10 @@ private void executeProcessing() {

if (eventHandler != null) {
if (!serverContextCreated) {
serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
TNonblockingSocket socket = (TNonblockingSocket) underlyingTransport;
SocketAddress remoteAddress = socket.getSocketChannel().socket().getRemoteSocketAddress();
serverContext =
eventHandler.createContext(requestProtocol, responseProtocol, remoteAddress);
serverContextCreated = true;
}
eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
Expand Down
Expand Up @@ -26,6 +26,7 @@ import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.types.enum
import com.github.ajalt.clikt.parameters.types.int
import com.github.ajalt.clikt.parameters.types.long
import java.net.SocketAddress
import kotlinx.coroutines.GlobalScope
import org.apache.thrift.TException
import org.apache.thrift.TMultiplexedProcessor
Expand Down Expand Up @@ -73,7 +74,8 @@ object TestServer {
}
}

internal class TestServerContext(var connectionId: Int) : ServerContext {
internal class TestServerContext(var connectionId: Int, var remoteAddress: SocketAddress) :
ServerContext {

override fun <T> unwrap(iface: Class<T>): T {
try {
Expand Down Expand Up @@ -102,10 +104,14 @@ object TestServer {
)
}

override fun createContext(input: TProtocol, output: TProtocol): ServerContext {
override fun createContext(
input: TProtocol,
output: TProtocol,
remoteAddress: SocketAddress
): ServerContext {
// we can create some connection level data which is stored while connection is alive &
// served
val ctx = TestServerContext(nextConnectionId++)
val ctx = TestServerContext(nextConnectionId++, remoteAddress)
println(
"TServerEventHandler.createContext - connection #" +
ctx.connectionId +
Expand Down