Skip to content

Commit

Permalink
fix: EXPOSED-382 ClassCastException when uuid().references() is used …
Browse files Browse the repository at this point in the history
…with EntityID column

Using references() invoked on a UUIDColumnType that targets an EntityIDColumnType<UUID>
causes a ClassCastException. This occurs because PreparedStatementApi.fillParameters()
tries to get the valueToDB() from EntityIDColumnType rather than the underlying UUIDColumnType
and it is not possible to cast the refValue UUID to EntityID<UUID>.

This exception does not happen with Int- or LongColumnType because their EntityID
variants are wrapped as AutoIncColumnTypes that delegate to the correct underlying
type. And this also doesn't happen if reference() is used because it creates an
EntityIDColumnType under-the-hood which avoids type mismatch.

This was most likely always happening but only causes an issue now that a stricter
column type-safety has been introduced, which does not tolerate well a foreign
key constraint between columns of different types.
  • Loading branch information
bog-walk committed May 10, 2024
1 parent db4c708 commit dac6b4f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
Expand Up @@ -30,7 +30,7 @@ class Column<T>(
val referee: Column<*>?
get() = foreignKey?.targetOf(this)

/** Returns the column that this column references, casted as a column of type [S], or `null` if the cast fails. */
/** Returns the column that this column references, cast as a column of type [S], or `null` if the cast fails. */
@Suppress("UNCHECKED_CAST")
fun <S : T> referee(): Column<S>? = referee as? Column<S>

Expand Down
Expand Up @@ -125,15 +125,15 @@ open class Entity<ID : Comparable<ID>>(val id: EntityID<ID>) {
}
}
else -> {
// @formatter:off
val castReferee = reference.referee<REF>()!!
val baseReferee = (castReferee.columnType as? EntityIDColumnType<REF>)?.idColumn ?: castReferee
factory.findWithCacheCondition({
reference.referee!!.getValue(this, desc) == refValue
}) {
reference.referee<REF>()!! eq refValue
baseReferee eq refValue
}.singleOrNull()?.also {
storeReferenceInCache(reference, it)
}
// @formatter:on
}
} ?: error("Cannot find ${factory.table.tableName} WHERE id=$refValue")
}
Expand Down
Expand Up @@ -4,7 +4,10 @@ import org.jetbrains.exposed.dao.LongEntity
import org.jetbrains.exposed.dao.LongEntityClass
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test
Expand All @@ -13,30 +16,43 @@ object LongIdTables {
object Cities : LongIdTable() {
val name = varchar("name", 50)
}

class City(id: EntityID<Long>) : LongEntity(id) {
companion object : LongEntityClass<City>(Cities)
var name by Cities.name
}

object People : LongIdTable() {
val name = varchar("name", 80)
val cityId = reference("city_id", Cities)
}

class Person(id: EntityID<Long>) : LongEntity(id) {
companion object : LongEntityClass<Person>(People)
var name by People.name
var city by City referencedOn People.cityId
}

object Towns : LongIdTable("towns") {
val cityId: Column<Long> = long("city_id").references(Cities.id)
}

class Town(id: EntityID<Long>) : LongEntity(id) {
companion object : LongEntityClass<Town>(Towns)
var city by City referencedOn Towns.cityId
}
}
class LongIdTableEntityTest : DatabaseTestsBase() {

@Test fun `create tables`() {
@Test
fun `create tables`() {
withTables(LongIdTables.Cities, LongIdTables.People) {
assertEquals(true, LongIdTables.Cities.exists())
assertEquals(true, LongIdTables.People.exists())
}
}

@Test fun `create records`() {
@Test
fun `create records`() {
withTables(LongIdTables.Cities, LongIdTables.People) {
val mumbai = LongIdTables.City.new { name = "Mumbai" }
val pune = LongIdTables.City.new { name = "Pune" }
Expand Down Expand Up @@ -64,7 +80,8 @@ class LongIdTableEntityTest : DatabaseTestsBase() {
}
}

@Test fun `update and delete records`() {
@Test
fun `update and delete records`() {
withTables(LongIdTables.Cities, LongIdTables.People) {
val mumbai = LongIdTables.City.new { name = "Mumbai" }
val pune = LongIdTables.City.new { name = "Pune" }
Expand Down Expand Up @@ -93,4 +110,19 @@ class LongIdTableEntityTest : DatabaseTestsBase() {
assertEquals(false, allPeople.contains(Pair("Tanu Arora", "Pune")))
}
}

@Test
fun testForeignKeyBetweenLongAndEntityIDColumns() {
withTables(LongIdTables.Cities, LongIdTables.Towns) {
val cId = LongIdTables.Cities.insertAndGetId {
it[name] = "City A"
}
LongIdTables.Towns.insert {
it[cityId] = cId.value
}

val town1 = LongIdTables.Town.all().single()
assertEquals(cId, town1.city.id)
}
}
}
Expand Up @@ -4,7 +4,10 @@ import org.jetbrains.exposed.dao.UUIDEntity
import org.jetbrains.exposed.dao.UUIDEntityClass
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.UUIDTable
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test
Expand Down Expand Up @@ -47,10 +50,18 @@ object UUIDTables {
var city by City.referencedOn(Addresses.city)
var address by Addresses.address
}

object Towns : UUIDTable("towns") {
val cityId: Column<UUID> = uuid("city_id").references(Cities.id)
}

class Town(id: EntityID<UUID>) : UUIDEntity(id) {
companion object : UUIDEntityClass<Town>(Towns)
var city by City referencedOn Towns.cityId
}
}

class UUIDTableEntityTest : DatabaseTestsBase() {

@Test
fun `create tables`() {
withTables(UUIDTables.Cities, UUIDTables.People) {
Expand Down Expand Up @@ -149,4 +160,19 @@ class UUIDTableEntityTest : DatabaseTestsBase() {
assertEquals("address2", address2.address)
}
}

@Test
fun testForeignKeyBetweenUUIDAndEntityIDColumns() {
withTables(UUIDTables.Cities, UUIDTables.Towns) {
val cId = UUIDTables.Cities.insertAndGetId {
it[name] = "City A"
}
UUIDTables.Towns.insert {
it[cityId] = cId.value
}

val town1 = UUIDTables.Town.all().single()
assertEquals(cId, town1.city.id)
}
}
}

0 comments on commit dac6b4f

Please sign in to comment.