Skip to content

Commit

Permalink
Merge pull request #7 from nats-io/1.1.0
Browse files Browse the repository at this point in the history
[ADDED] V1.1 Support for TLS authentication
  • Loading branch information
jnmoyne committed Sep 13, 2023
2 parents 05d7825 + f004b17 commit 044e54c
Show file tree
Hide file tree
Showing 9 changed files with 1,403 additions and 63 deletions.
2 changes: 1 addition & 1 deletion load_balanced/build.sbt
@@ -1,6 +1,6 @@

name := "nats-spark-connector"
version := "1.0.0"
version := "1.1.0"
scalaVersion := "2.12.14"

val sparkVersion = "3.3.0"
Expand Down
168 changes: 112 additions & 56 deletions load_balanced/src/main/scala/natsconnector/NatsConfig.scala
@@ -1,11 +1,11 @@
package natsconnector

import java.time.Duration

import io.nats.client.JetStream
import io.nats.client.JetStreamManagement
import io.nats.client.api.StreamInfo
import java.io.IOException

import java.io.{BufferedInputStream, FileInputStream, IOException}
import io.nats.client.JetStreamApiException
import io.nats.client.Nats
import io.nats.client.Options
Expand All @@ -18,7 +18,7 @@ import io.nats.client.ConnectionListener.Events
import io.nats.client.AuthHandler
import io.nats.client.api.StreamConfiguration
import io.nats.client.api.StorageType
import io.nats.client.{PushSubscribeOptions, PullSubscribeOptions}
import io.nats.client.{PullSubscribeOptions, PushSubscribeOptions}
import io.nats.client.api.ConsumerConfiguration
import io.nats.client.api.AckPolicy
import io.nats.client.api.RetentionPolicy
Expand All @@ -27,10 +27,14 @@ import scala.collection.JavaConverters._
import org.apache.log4j.PropertyConfigurator
import org.apache.log4j.Logger
import org.apache.log4j.Level

import java.util.Properties
import java.io.FileInputStream
import org.apache.spark.sql.SparkSession

import java.nio.file.{Files, Paths}
import java.security.{KeyStore, SecureRandom}
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}

object NatsConfigSource {
val config = new NatsConfig(true)
}
Expand All @@ -39,7 +43,7 @@ object NatsConfigSink {
val config = new NatsConfig(false)
}

class NatsConfig(isSource: Boolean) {
class NatsConfig(isSource: Boolean) {
val isLocal = false
// Note on security:
// Set the environment variable NATS_NKEY to use challenge response authentication by setting a file containing your private key.
Expand All @@ -52,16 +56,16 @@ class NatsConfig(isSource: Boolean) {
var host = "0.0.0.0"
var port = "4222"
var server: Option[String] = None
var allowReconnect = true
var connectionTimeout = Duration.ofSeconds(20)
var pingInterval = Duration.ofSeconds(10)
var reconnectWait = Duration.ofSeconds(20)
var messageReceiveWaitTime = Duration.ofMillis(50)
var allowReconnect = true
var connectionTimeout = Duration.ofSeconds(20)
var pingInterval = Duration.ofSeconds(10)
var reconnectWait = Duration.ofSeconds(20)
var messageReceiveWaitTime = Duration.ofMillis(50)
var flushWaitTime = Duration.ofMillis(
0
) // how long to wait for a connection to flush all msgs; '0' waits forever
var msgFetchBatchSize = 100 // how many messages to get at once. Will get any messages where 1<bach_size<=100.
// If zero messages then subsciber will wait messageReceiveWaitTime before giving up.
// If zero messages then subsciber will wait messageReceiveWaitTime before giving up.

// ============== JetStream stream Config Values
// TODO: add replication configuration
Expand Down Expand Up @@ -203,29 +207,29 @@ class NatsConfig(isSource: Boolean) {
val logger:Logger = NatsLogger.logger
logger.debug(
"Current internal config state:\n"
+ s"host = ${this.host}\n"
+ s"port = ${this.port}\n"
+ s"server = ${this.server}\n"
+ s"allowReconnect = ${this.allowReconnect}\n"
+ s"connectionTimeout = ${this.connectionTimeout}\n"
+ s"pingInterval = ${this.pingInterval}\n"
+ s"reconnectWait = ${this.reconnectWait}\n"
+ s"messageReceiveWaitTime = ${this.messageReceiveWaitTime}\n"
+ s"flushWaitTime = ${this.flushWaitTime}\n"
+ s"msgFetchBatchSize = ${this.msgFetchBatchSize}\n"
+ s"streamName = ${this.streamName}\n"
+ s"storageType = ${this.storageType}\n"
+ s"streamSubjects = ${this.streamSubjects}\n"
+ s"durable = ${this.durable}\n"
+ s"ackPolicy = ${this.ackPolicy}\n"
+ s"retentionPolicy = ${this.retentionPolicy}\n"
+ s"deliverPolicy = ${this.deliverPolicy}\n"
+ s"msgAckWaitTime = ${this.msgAckWaitTime}\n"
+ s"dateTimeFormat = ${this.dateTimeFormat}\n"
+ s"numListeners = ${this.numListeners}\n"
+ s"[Connection] options = ${this.options}\n"
+ s"[Nats connection] nc = ${this.nc}\n"
+ s"[JetStream context] js= ${this.js}\n"
+ s"host = ${this.host}\n"
+ s"port = ${this.port}\n"
+ s"server = ${this.server}\n"
+ s"allowReconnect = ${this.allowReconnect}\n"
+ s"connectionTimeout = ${this.connectionTimeout}\n"
+ s"pingInterval = ${this.pingInterval}\n"
+ s"reconnectWait = ${this.reconnectWait}\n"
+ s"messageReceiveWaitTime = ${this.messageReceiveWaitTime}\n"
+ s"flushWaitTime = ${this.flushWaitTime}\n"
+ s"msgFetchBatchSize = ${this.msgFetchBatchSize}\n"
+ s"streamName = ${this.streamName}\n"
+ s"storageType = ${this.storageType}\n"
+ s"streamSubjects = ${this.streamSubjects}\n"
+ s"durable = ${this.durable}\n"
+ s"ackPolicy = ${this.ackPolicy}\n"
+ s"retentionPolicy = ${this.retentionPolicy}\n"
+ s"deliverPolicy = ${this.deliverPolicy}\n"
+ s"msgAckWaitTime = ${this.msgAckWaitTime}\n"
+ s"dateTimeFormat = ${this.dateTimeFormat}\n"
+ s"numListeners = ${this.numListeners}\n"
+ s"[Connection] options = ${this.options}\n"
+ s"[Nats connection] nc = ${this.nc}\n"
+ s"[JetStream context] js= ${this.js}\n"
)
}
}
Expand Down Expand Up @@ -273,28 +277,28 @@ class NatsConfig(isSource: Boolean) {
}
val subjectArray = this.streamSubjects.get.replace(" ", "").split(",")
subjectArray.zipWithIndex.foreach {
case (subject, idx) => {
val configBuilder = ConsumerConfiguration
.builder()
.ackWait(this.msgAckWaitTime)
.ackPolicy(this.ackPolicy)
.filterSubject(subject)
.deliverPolicy(this.deliverPolicy)
if(this.durable != None)
configBuilder.durable(s"${this.durable.get}-${idx}")
else {
// TODO: Add configBuilder.InactiveThreshold()
}
jsm.addOrUpdateConsumer(this.streamName.get, configBuilder.build())
case (subject, idx) => {
val configBuilder = ConsumerConfiguration
.builder()
.ackWait(this.msgAckWaitTime)
.ackPolicy(this.ackPolicy)
.filterSubject(subject)
.deliverPolicy(this.deliverPolicy)
if(this.durable.isDefined)
configBuilder.durable(s"${this.durable.get}-${idx}")
else {
// TODO: Add configBuilder.InactiveThreshold()
}
jsm.addOrUpdateConsumer(this.streamName.get, configBuilder.build())
}
}
this.nc.get.jetStream()
}

private def getStreamInfoOrNullIfNonExistent(
jsm: JetStreamManagement,
streamName: String
): StreamInfo = {
jsm: JetStreamManagement,
streamName: String
): StreamInfo = {
try {
return jsm.getStreamInfo(streamName)
} catch {
Expand All @@ -312,9 +316,9 @@ class NatsConfig(isSource: Boolean) {
}

private def createConnectionOptions(
server: String,
allowReconnect: Boolean
): Options = {
server: String,
allowReconnect: Boolean
): Options = {
val el = new ErrorListener() {
override def exceptionOccurred(conn: Connection, exp: Exception): Unit = {
System.out.println("Exception " + exp.getMessage());
Expand All @@ -325,9 +329,9 @@ class NatsConfig(isSource: Boolean) {
}

override def slowConsumerDetected(
conn: Connection,
consumer: Consumer
): Unit = {
conn: Connection,
consumer: Consumer
): Unit = {
System.out.println("Slow consumer");
}
}
Expand Down Expand Up @@ -363,7 +367,59 @@ class NatsConfig(isSource: Boolean) {
System.getenv("NATS_CREDS") != null && System.getenv("NATS_CREDS") != ""
) {
builder.authHandler(Nats.credentials(System.getenv("NATS_CREDS")));
} else if (System.getenv("NATS_TLS_KEY_STORE") != null && System.getenv("NATS_TLS_KEY_STORE") != "" && System.getenv("NATS_TLS_TRUST_STORE") != null && System.getenv("NATS_TLS_TRUST_STORE") != "") {

val tlsAlgo = if (System.getenv("NATS_TLS_ALGO") != null && System.getenv("NATS_TLS_ALGO") != "") {
System.getenv("NATS_TLS_ALGO")
} else "SunX509"

val instanceType = if (System.getenv("NATS_TLS_STORE_TYPE") != null && System.getenv("NATS_TLS_STORE_TYPE") != "") {
System.getenv("NATS_TLS_STORE_TYPE")
} else "JKS"

val keyStorePassword = if (System.getenv("NATS_TLS_KEY_STORE_PASSWORD") != null) {
System.getenv("NATS_TLS_KEY_STORE_PASSWORD").toCharArray
} else "".toCharArray

val trustStorePassword = if (System.getenv("NATS_TLS_TRUST_STORE_PASSWORD") != null) {
System.getenv("NATS_TLS_TRUST_STORE_PASSWORD").toCharArray
} else "".toCharArray


val ctx = javax.net.ssl.SSLContext.getInstance(Options.DEFAULT_SSL_PROTOCOL)

val keyStore = KeyStore.getInstance(instanceType)

val inputKeyF = new BufferedInputStream(Files.newInputStream(Paths.get(System.getenv("NATS_TLS_KEY_STORE"))))
try {
keyStore.load(inputKeyF, keyStorePassword)
} catch {
case e: Exception => System.out.println("Exception " + e.getMessage)
} finally {
if (inputKeyF != null) {inputKeyF.close()}
}

val kmsFactory = KeyManagerFactory.getInstance(tlsAlgo)
kmsFactory.init(keyStore, keyStorePassword)
val kms = kmsFactory.getKeyManagers

val trustStore = KeyStore.getInstance(instanceType)
val inputTrustF = new BufferedInputStream(Files.newInputStream(Paths.get(System.getenv("NATS_TLS_TRUST_STORE"))))
try {
trustStore.load(inputTrustF, trustStorePassword)
} catch {
case e: Exception => System.out.println("Exception " + e.getMessage)
} finally {if (inputTrustF != null) inputTrustF.close()}

val tmsFactory = TrustManagerFactory.getInstance(tlsAlgo)
tmsFactory.init(trustStore)
val tms = tmsFactory.getTrustManagers

ctx.init(kms, tms, new SecureRandom())

builder.sslContext(ctx)
}

return builder.build()
}
}
Expand Down
2 changes: 1 addition & 1 deletion load_balanced/src/test/scala/natstest/NatsTestDriver.scala
Expand Up @@ -16,7 +16,7 @@ object NatsTestDriver extends App {
Map(
"nats.stream.name" -> "TestStream",
"nats.stream.subjects" -> "test1, test2",
"nats.host" -> "0.0.0.0",
"nats.host" -> "localhost",
"nats.port" -> "4222",
"nats.msg.ack.wait.secs" -> "10",
"nats.durable.name" -> "Durable"
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion partitioned/build.sbt
@@ -1,6 +1,6 @@

name := "nats-spark-connector"
version := "1.0.0"
version := "1.1.0"
scalaVersion := "2.12.14"

val sparkVersion = "3.3.0"
Expand Down
64 changes: 61 additions & 3 deletions partitioned/src/main/scala/natsconnector/NatsConfig.scala
@@ -1,11 +1,9 @@
package natsconnector

import scala.collection.JavaConverters._

import java.time.Duration
import java.nio.charset.StandardCharsets
import java.io.IOException

import java.io.{BufferedInputStream, IOException}
import org.apache.spark.sql.SparkSession

//import org.slf4j.Logger
Expand Down Expand Up @@ -43,6 +41,10 @@ import java.util.Properties
import org.apache.log4j.PropertyConfigurator
import java.io.FileInputStream

import java.nio.file.{Files, Paths}
import java.security.{KeyStore, SecureRandom}
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}


object NatsConfigSource {
val config = new NatsConfig(true)
Expand Down Expand Up @@ -220,7 +222,63 @@ class NatsConfig(isSource:Boolean) {
System.getenv("NATS_CREDS") != null && System.getenv("NATS_CREDS") != ""
) {
builder.authHandler(Nats.credentials(System.getenv("NATS_CREDS")));
} else if (System.getenv("NATS_TLS_KEY_STORE") != null && System.getenv("NATS_TLS_KEY_STORE") != "" && System.getenv("NATS_TLS_TRUST_STORE") != null && System.getenv("NATS_TLS_TRUST_STORE") != "") {

val tlsAlgo = if (System.getenv("NATS_TLS_ALGO") != null && System.getenv("NATS_TLS_ALGO") != "") {
System.getenv("NATS_TLS_ALGO")
} else "SunX509"

val instanceType = if (System.getenv("NATS_TLS_STORE_TYPE") != null && System.getenv("NATS_TLS_STORE_TYPE") != "") {
System.getenv("NATS_TLS_STORE_TYPE")
} else "JKS"

val keyStorePassword = if (System.getenv("NATS_TLS_KEY_STORE_PASSWORD") != null) {
System.getenv("NATS_TLS_KEY_STORE_PASSWORD").toCharArray
} else "".toCharArray

val trustStorePassword = if (System.getenv("NATS_TLS_TRUST_STORE_PASSWORD") != null) {
System.getenv("NATS_TLS_TRUST_STORE_PASSWORD").toCharArray
} else "".toCharArray


val ctx = javax.net.ssl.SSLContext.getInstance(Options.DEFAULT_SSL_PROTOCOL)

val keyStore = KeyStore.getInstance(instanceType)

val inputKeyF = new BufferedInputStream(Files.newInputStream(Paths.get(System.getenv("NATS_TLS_KEY_STORE"))))
try {
keyStore.load(inputKeyF, keyStorePassword)
} catch {
case e: Exception => System.out.println("Exception " + e.getMessage)
} finally {
if (inputKeyF != null) {
inputKeyF.close()
}
}

val kmsFactory = KeyManagerFactory.getInstance(tlsAlgo)
kmsFactory.init(keyStore, keyStorePassword)
val kms = kmsFactory.getKeyManagers

val trustStore = KeyStore.getInstance(instanceType)
val inputTrustF = new BufferedInputStream(Files.newInputStream(Paths.get(System.getenv("NATS_TLS_TRUST_STORE"))))
try {
trustStore.load(inputTrustF, trustStorePassword)
} catch {
case e: Exception => System.out.println("Exception " + e.getMessage)
} finally {
if (inputTrustF != null) inputTrustF.close()
}

val tmsFactory = TrustManagerFactory.getInstance(tlsAlgo)
tmsFactory.init(trustStore)
val tms = tmsFactory.getTrustManagers

ctx.init(kms, tms, new SecureRandom())

builder.sslContext(ctx)
}

return builder.build()
}

Expand Down

0 comments on commit 044e54c

Please sign in to comment.