Merge pull request #110 from massimosiani/doobie-rc4
Doobie rc4
voidcontext committed Sep 28, 2023
2 parents b40ced6 + 48cba1e commit 89efeb5
Showing 7 changed files with 456 additions and 14 deletions.
22 changes: 19 additions & 3 deletions build.sbt
@@ -1,6 +1,6 @@
import microsites.MicrositesPlugin.autoImport.micrositeDescription

-val scala213Version = "2.13.10"
+val scala213Version = "2.13.12"
val scala3Version = "3.3.0"

val scalaVersions = Seq(scala213Version, scala3Version)
@@ -68,8 +68,9 @@ val http4sMilestoneVersion = "1.0.0-M40"
val http4sStableVersion = "0.23.23"
val circeVersion = "0.14.3"
val slf4jVersion = "1.7.36"
val fs2Version = "3.8.0"
val doobieVersion = "1.0.0-RC2"
val fs2Version = "3.9.1"
val doobieVersion = "1.0.0-RC4"
val doobieLegacyVersion = "1.0.0-RC2"

lazy val natchezDatadog = projectMatrix
.in(file("natchez-extras-datadog"))
@@ -206,6 +207,20 @@ lazy val natchezDoobie = projectMatrix
)
.dependsOn(core)

lazy val natchezDoobieLegacy = projectMatrix
.in(file("natchez-extras-doobie-legacy"))
.jvmPlatform(scalaVersions = scalaVersions)
.enablePlugins(GitVersioning)
.settings(common :+ (name := "natchez-extras-doobie-legacy"))
.settings(
libraryDependencies ++= Seq(
"org.tpolecat" %% "natchez-core" % natchezVersion,
"org.tpolecat" %% "doobie-core" % doobieLegacyVersion,
"org.tpolecat" %% "doobie-h2" % doobieLegacyVersion % Test
)
)
.dependsOn(core)

lazy val core = projectMatrix
.in(file("natchez-extras-core"))
.jvmPlatform(scalaVersions = scalaVersions)
@@ -305,6 +320,7 @@ lazy val root = (project in file("."))
.aggregate(natchezCombine.projectRefs: _*)
.aggregate(natchezSlf4j.projectRefs: _*)
.aggregate(natchezDoobie.projectRefs: _*)
.aggregate(natchezDoobieLegacy.projectRefs: _*)
.aggregate(natchezLog4Cats.projectRefs: _*)
.aggregate(natchezHttp4s.projectRefs: _*)
.aggregate(natchezFs2.projectRefs: _*)
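For reference, a consuming project would add the new module alongside the existing one. A minimal build.sbt sketch, assuming the com.ovoenergy organisation implied by the package names in this diff and a placeholder version:

// Consumer build.sbt sketch: natchez-extras-doobie now targets doobie 1.0.0-RC4,
// while natchez-extras-doobie-legacy stays on doobie 1.0.0-RC2.
libraryDependencies ++= Seq(
  "com.ovoenergy" %% "natchez-extras-doobie"        % "<latest-version>",
  "com.ovoenergy" %% "natchez-extras-doobie-legacy" % "<latest-version>"
)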
3 changes: 2 additions & 1 deletion docs/docs/docs/natchez-doobie.md
@@ -63,7 +63,8 @@ object NatchezDoobie extends IOApp {
driver = "org.postgresql.Driver",
url = "jdbc:postgresql:example",
user = "postgres",
-pass = "password" // of course don't hard code these details in your applications!
+password = "password", // of course don't hard code these details in your applications!
+logHandler = None,
)
)

@@ -0,0 +1,233 @@
package com.ovoenergy.natchez.extras.doobie

import java.io.{InputStream, Reader}
import java.net.URL
import java.sql.{Array => _, _}
import java.util.Calendar
import scala.annotation.nowarn

/**
 * This is an absolutely abominable brute-force solution to linking PreparedStatements
 * with a SQL string so we can include it in traces, but hey, it is a one-time cost.
 * Pretend this doesn't exist and you never had to see it.
 */
@nowarn
private[doobie] case class TracedStatement(
p: PreparedStatement,
queryString: String
) extends PreparedStatement {
def executeQuery(): ResultSet = p.executeQuery()

def executeUpdate(): Int = p.executeUpdate()

def setNull(parameterIndex: Int, sqlType: Int): Unit = p.setNull(parameterIndex, sqlType)

def setBoolean(parameterIndex: Int, x: Boolean): Unit = p.setBoolean(parameterIndex, x)

def setByte(parameterIndex: Int, x: Byte): Unit = p.setByte(parameterIndex, x)

def setShort(parameterIndex: Int, x: Short): Unit = p.setShort(parameterIndex, x)

def setInt(parameterIndex: Int, x: Int): Unit = p.setInt(parameterIndex, x)

def setLong(parameterIndex: Int, x: Long): Unit = p.setLong(parameterIndex, x)

def setFloat(parameterIndex: Int, x: Float): Unit = p.setFloat(parameterIndex, x)

def setDouble(parameterIndex: Int, x: Double): Unit = p.setDouble(parameterIndex, x)

def setBigDecimal(parameterIndex: Int, x: java.math.BigDecimal): Unit = p.setBigDecimal(parameterIndex, x)

def setString(parameterIndex: Int, x: String): Unit = p.setString(parameterIndex, x)

def setBytes(parameterIndex: Int, x: Array[Byte]): Unit = p.setBytes(parameterIndex, x)

def setDate(parameterIndex: Int, x: Date): Unit = p.setDate(parameterIndex, x)

def setTime(parameterIndex: Int, x: Time): Unit = p.setTime(parameterIndex, x)

def setTimestamp(parameterIndex: Int, x: Timestamp): Unit = p.setTimestamp(parameterIndex, x)

def setAsciiStream(parameterIndex: Int, x: InputStream, length: Int): Unit =
p.setAsciiStream(parameterIndex, x, length)

def setUnicodeStream(parameterIndex: Int, x: InputStream, length: Int): Unit =
p.setUnicodeStream(parameterIndex, x, length)

def setBinaryStream(parameterIndex: Int, x: InputStream, length: Int): Unit =
p.setBinaryStream(parameterIndex, x, length)

def clearParameters(): Unit = p.clearParameters()

def setObject(parameterIndex: Int, x: Any, targetSqlType: Int): Unit =
p.setObject(parameterIndex, x, targetSqlType)

def setObject(parameterIndex: Int, x: Any): Unit = p.setObject(parameterIndex, x)

def execute(): Boolean = p.execute()

def addBatch(): Unit = p.addBatch()

def setCharacterStream(parameterIndex: Int, reader: Reader, length: Int): Unit =
p.setCharacterStream(parameterIndex, reader, length)

def setRef(parameterIndex: Int, x: Ref): Unit = p.setRef(parameterIndex, x)

def setBlob(parameterIndex: Int, x: Blob): Unit = p.setBlob(parameterIndex, x)

def setClob(parameterIndex: Int, x: Clob): Unit = p.setClob(parameterIndex, x)

def setArray(parameterIndex: Int, x: java.sql.Array): Unit = p.setArray(parameterIndex, x)

def getMetaData: ResultSetMetaData = p.getMetaData

def setDate(parameterIndex: Int, x: Date, cal: Calendar): Unit = p.setDate(parameterIndex, x, cal)

def setTime(parameterIndex: Int, x: Time, cal: Calendar): Unit = p.setTime(parameterIndex, x, cal)

def setTimestamp(parameterIndex: Int, x: Timestamp, cal: Calendar): Unit =
p.setTimestamp(parameterIndex, x, cal)

def setNull(parameterIndex: Int, sqlType: Int, typeName: String): Unit =
p.setNull(parameterIndex, sqlType, typeName)

def setURL(parameterIndex: Int, x: URL): Unit = p.setURL(parameterIndex, x)

def getParameterMetaData: ParameterMetaData = p.getParameterMetaData

def setRowId(parameterIndex: Int, x: RowId): Unit = p.setRowId(parameterIndex, x)

def setNString(parameterIndex: Int, value: String): Unit = p.setNString(parameterIndex, value)

def setNCharacterStream(parameterIndex: Int, value: Reader, length: Long): Unit =
p.setNCharacterStream(parameterIndex, value, length)

def setNClob(parameterIndex: Int, value: NClob): Unit = p.setNClob(parameterIndex, value)

def setClob(parameterIndex: Int, reader: Reader, length: Long): Unit =
p.setClob(parameterIndex, reader, length)

def setBlob(parameterIndex: Int, inputStream: InputStream, length: Long): Unit =
p.setBlob(parameterIndex, inputStream, length)

def setNClob(parameterIndex: Int, reader: Reader, length: Long): Unit =
p.setNClob(parameterIndex, reader, length)

def setSQLXML(parameterIndex: Int, xmlObject: SQLXML): Unit = p.setSQLXML(parameterIndex, xmlObject)

def setObject(parameterIndex: Int, x: Any, targetSqlType: Int, scaleOrLength: Int): Unit =
p.setObject(parameterIndex, x, targetSqlType, scaleOrLength)

def setAsciiStream(parameterIndex: Int, x: InputStream, length: Long): Unit =
p.setAsciiStream(parameterIndex, x, length)

def setBinaryStream(parameterIndex: Int, x: InputStream, length: Long): Unit =
p.setBinaryStream(parameterIndex, x, length)

def setCharacterStream(parameterIndex: Int, reader: Reader, length: Long): Unit =
p.setCharacterStream(parameterIndex, reader, length)

def setAsciiStream(parameterIndex: Int, x: InputStream): Unit = p.setAsciiStream(parameterIndex, x)

def setBinaryStream(parameterIndex: Int, x: InputStream): Unit = p.setBinaryStream(parameterIndex, x)

def setCharacterStream(parameterIndex: Int, reader: Reader): Unit =
p.setCharacterStream(parameterIndex, reader)

def setNCharacterStream(parameterIndex: Int, value: Reader): Unit =
p.setNCharacterStream(parameterIndex, value)

def setClob(parameterIndex: Int, reader: Reader): Unit = p.setClob(parameterIndex, reader)

def setBlob(parameterIndex: Int, inputStream: InputStream): Unit = p.setBlob(parameterIndex, inputStream)

def setNClob(parameterIndex: Int, reader: Reader): Unit = p.setNClob(parameterIndex, reader)

def executeQuery(sql: String): ResultSet = p.executeQuery(sql)

def executeUpdate(sql: String): Int = p.executeUpdate(sql)

def close(): Unit = p.close()

def getMaxFieldSize: Int = p.getMaxFieldSize

def setMaxFieldSize(max: Int): Unit = p.setMaxFieldSize(max)

def getMaxRows: Int = p.getMaxRows

def setMaxRows(max: Int): Unit = p.setMaxRows(max)

def setEscapeProcessing(enable: Boolean): Unit = p.setEscapeProcessing(enable)

def getQueryTimeout: Int = p.getQueryTimeout

def setQueryTimeout(seconds: Int): Unit = p.setQueryTimeout(seconds)

def cancel(): Unit = p.cancel()

def getWarnings: SQLWarning = p.getWarnings

def clearWarnings(): Unit = p.clearWarnings()

def setCursorName(name: String): Unit = p.setCursorName(name)

def execute(sql: String): Boolean = p.execute(sql)

def getResultSet: ResultSet = p.getResultSet

def getUpdateCount: Int = p.getUpdateCount

def getMoreResults: Boolean = p.getMoreResults()

def setFetchDirection(direction: Int): Unit = p.setFetchDirection(direction)

def getFetchDirection: Int = p.getFetchDirection

def setFetchSize(rows: Int): Unit = p.setFetchSize(rows)

def getFetchSize: Int = p.getFetchSize

def getResultSetConcurrency: Int = p.getResultSetConcurrency

def getResultSetType: Int = p.getResultSetType

def addBatch(sql: String): Unit = p.addBatch(sql)

def clearBatch(): Unit = p.clearBatch()

def executeBatch(): Array[Int] = p.executeBatch()

def getConnection: Connection = p.getConnection

def getMoreResults(current: Int): Boolean = p.getMoreResults(current)

def getGeneratedKeys: ResultSet = p.getGeneratedKeys

def executeUpdate(sql: String, autoGeneratedKeys: Int): Int = p.executeUpdate(sql, autoGeneratedKeys)

def executeUpdate(sql: String, columnIndexes: Array[Int]): Int = p.executeUpdate(sql, columnIndexes)

def executeUpdate(sql: String, columnNames: Array[String]): Int = p.executeUpdate(sql, columnNames)

def execute(sql: String, autoGeneratedKeys: Int): Boolean = p.execute(sql, autoGeneratedKeys)

def execute(sql: String, columnIndexes: Array[Int]): Boolean = p.execute(sql, columnIndexes)

def execute(sql: String, columnNames: Array[String]): Boolean = p.execute(sql, columnNames)

def getResultSetHoldability: Int = p.getResultSetHoldability

def isClosed: Boolean = p.isClosed

def setPoolable(poolable: Boolean): Unit = p.setPoolable(poolable)

def isPoolable: Boolean = p.isPoolable

def closeOnCompletion(): Unit = p.closeOnCompletion()

def isCloseOnCompletion: Boolean = p.isCloseOnCompletion

def unwrap[T](iface: Class[T]): T = p.unwrap(iface)

def isWrapperFor(iface: Class[_]): Boolean = p.isWrapperFor(iface)
}
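To illustrate the idea, a hypothetical snippet (only compilable inside the com.ovoenergy.natchez.extras.doobie package, since TracedStatement is private[doobie]): the wrapper simply carries the original SQL next to the real statement so that later code can pattern match it back out.

// Sketch: attach the SQL text to the statement, then recover it by pattern matching.
def describe(connection: java.sql.Connection): String = {
  val sql = "SELECT count(*) FROM users"
  val stmt: java.sql.PreparedStatement = TracedStatement(connection.prepareStatement(sql), sql)
  stmt match {
    case TracedStatement(_, queryString) => queryString // "SELECT count(*) FROM users"
    case _                               => "<no SQL attached>" // a plain, unwrapped statement
  }
}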
@@ -0,0 +1,87 @@
package com.ovoenergy.natchez.extras.doobie

import cats.data.Kleisli
import cats.effect.Async
import cats.implicits.catsSyntaxFlatMapOps
import com.ovoenergy.natchez.extras.core.Config
import com.ovoenergy.natchez.extras.core.Config.ServiceAndResource
import doobie.{KleisliInterpreter, WeakAsync}
import doobie.util.transactor.Transactor
import natchez.{Span, Trace}

import java.sql.{Connection, PreparedStatement, ResultSet}

object TracedTransactor {
private val DefaultResourceName = "db.execute"

type Traced[F[_], A] = Kleisli[F, Span[F], A]
def apply[F[_]: Async](
service: String,
transactor: Transactor[F]
): Transactor[Traced[F, *]] = {
val kleisliTransactor = transactor
.mapK(Kleisli.liftK[F, Span[F]])(implicitly, Async.asyncForKleisli(implicitly))
trace(ServiceAndResource(s"$service-db", DefaultResourceName), kleisliTransactor)
}

private val commentNamedQueryRegEx = """--\s*Name:\s*(\w+)""".r

private def extractQueryNameOrSql(sql: String): String =
commentNamedQueryRegEx.findFirstMatchIn(sql).flatMap(m => Option(m.group(1))).getOrElse(sql)

private def formatQuery(q: String): String =
q.replace("\n", " ").replaceAll("\\s+", " ").trim()

def trace[F[_]: Trace: Async](
config: Config,
transactor: Transactor[F]
): Transactor[F] =
transactor
.copy(
interpret0 = createInterpreter(config, Async[F]).ConnectionInterpreter
)

private def createInterpreter[F[_]: Trace](config: Config, F: Async[F]): KleisliInterpreter[F] = {
new KleisliInterpreter[F] {
implicit val asyncM: WeakAsync[F] =
WeakAsync.doobieWeakAsyncForAsync(F)

override lazy val PreparedStatementInterpreter: PreparedStatementInterpreter =
new PreparedStatementInterpreter {

type TracedOp[A] = Kleisli[F, PreparedStatement, A] //PreparedStatement => F[A]

def runTraced[A](f: TracedOp[A]): TracedOp[A] =
Kleisli {
case TracedStatement(p, sql) =>
Trace[F].span(config.fullyQualifiedSpanName(formatQuery(extractQueryNameOrSql(sql))))(
Trace[F].put("span.type" -> "db") >> f(p)
)
case a =>
f(a)
}

override val executeBatch: TracedOp[Array[Int]] =
runTraced(super.executeBatch)

override val executeLargeBatch: TracedOp[Array[Long]] =
runTraced(super.executeLargeBatch)

override val execute: TracedOp[Boolean] =
runTraced(super.execute)

override val executeUpdate: TracedOp[Int] =
runTraced(super.executeUpdate)

override val executeQuery: TracedOp[ResultSet] =
runTraced(super.executeQuery)
}

override lazy val ConnectionInterpreter: ConnectionInterpreter =
new ConnectionInterpreter {
override def prepareStatement(a: String): Kleisli[F, Connection, PreparedStatement] =
super.prepareStatement(a).map(TracedStatement(_, a): PreparedStatement)
}
}
}
}
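
For context, a minimal usage sketch of the interpreter above, assuming the API shown in this diff (the service name and query are illustrative, and the root Span[IO] would come from a natchez EntryPoint, not shown here):

import cats.effect.IO
import com.ovoenergy.natchez.extras.doobie.TracedTransactor
import doobie.Transactor
import doobie.implicits._
import natchez.Span

object TracedTransactorExample {
  // Kleisli[IO, Span[IO], A]: every query needs a span to run in
  type TracedIO[A] = TracedTransactor.Traced[IO, A]

  // Plain transactor, using the doobie 1.0.0-RC4 fromDriverManager signature from the docs diff
  val xa: Transactor[IO] =
    Transactor.fromDriverManager[IO](
      driver = "org.postgresql.Driver",
      url = "jdbc:postgresql:example",
      user = "postgres",
      password = "password",
      logHandler = None
    )

  // Wrap it: execute/executeQuery/executeUpdate/batch calls now open child spans
  val traced: Transactor[TracedIO] = TracedTransactor("my-service", xa)

  // The "-- Name: count_users" comment is picked up by extractQueryNameOrSql,
  // so the span is named after count_users rather than the raw SQL text
  def countUsers(root: Span[IO]): IO[Long] =
    sql"""-- Name: count_users
          SELECT count(*) FROM users"""
      .query[Long]
      .unique
      .transact(traced)
      .run(root)
}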
