/
TestServer.kt
120 lines (99 loc) · 3.75 KB
/
TestServer.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/*
* Thrifty
*
* Copyright (c) Microsoft Corporation
*
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF TITLE,
* FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR NON-INFRINGEMENT.
*
* See the Apache Version 2.0 License for specific language governing permissions and limitations under the License.
*/
package com.microsoft.thrifty.integration.conformance.server
import com.microsoft.thrifty.integration.kgen.ThriftTestProcessor
import com.microsoft.thrifty.protocol.BinaryProtocol
import com.microsoft.thrifty.protocol.CompactProtocol
import com.microsoft.thrifty.protocol.JsonProtocol
import com.microsoft.thrifty.protocol.Protocol
import com.microsoft.thrifty.testing.ServerProtocol
import com.microsoft.thrifty.transport.Transport
import com.sun.net.httpserver.HttpContext
import com.sun.net.httpserver.HttpExchange
import com.sun.net.httpserver.HttpServer
import kotlinx.coroutines.runBlocking
import okio.Buffer
import org.junit.jupiter.api.extension.AfterEachCallback
import org.junit.jupiter.api.extension.BeforeEachCallback
import org.junit.jupiter.api.extension.Extension
import org.junit.jupiter.api.extension.ExtensionContext
import java.net.InetSocketAddress
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/**
 * A JUnit 5 extension that serves the `ThriftTest` service over HTTP for conformance tests.
 *
 * Before each test an [HttpServer] is bound to an ephemeral port on `localhost` (query it with
 * [port]); after each test the server is stopped. Each request body is decoded with the
 * configured [protocol], passed to [processor], and the encoded reply is sent back as the
 * response body.
 *
 * @param protocol the wire protocol used both to decode requests and to encode responses.
 */
class TestServer(private val protocol: ServerProtocol = ServerProtocol.BINARY) :
    Extension,
    BeforeEachCallback,
    AfterEachCallback {

    /** Processor that dispatches decoded calls to a [ThriftTestHandler]. */
    val processor = ThriftTestProcessor(ThriftTestHandler())

    /** The running server, or `null` when not started or already stopped. */
    private var server: HttpServer? = null

    /**
     * The executor the server dispatches exchanges on. [HttpServer.stop] does not shut down
     * the executor it was given, so a reference is kept here and shut down in [cleanupServer];
     * otherwise every started server would leak a worker thread.
     */
    private var serverExecutor: ExecutorService? = null

    /**
     * An in-memory [Transport] backed by an okio [Buffer]: reads consume bytes from [b]
     * and writes append bytes to it.
     */
    class TestTransport(
        val b: Buffer = Buffer()
    ) : Transport {
        override fun read(buffer: ByteArray, offset: Int, count: Int) = b.read(buffer, offset, count)

        override fun write(buffer: ByteArray, offset: Int, count: Int) {
            b.write(buffer, offset, count)
        }

        override fun flush() = b.flush()

        override fun close() = b.close()
    }

    /**
     * Handles one HTTP exchange: reads the whole request body, runs [processor] once over it,
     * and replies with status 200 carrying whatever the processor wrote.
     *
     * The exchange is always closed, even when decoding or processing throws, so the client
     * observes the connection ending instead of waiting indefinitely for a response.
     */
    private fun handleRequest(exchange: HttpExchange) {
        try {
            val requestBytes = exchange.requestBody.use { Buffer().readFrom(it) }
            val inputTransport = TestTransport(requestBytes)
            val outputTransport = TestTransport()
            val input = protocolFactory(inputTransport)
            val output = protocolFactory(outputTransport)

            runBlocking {
                processor.process(input, output)
            }

            val responseLength = outputTransport.b.size
            // HttpExchange treats a length of 0 as "use chunked encoding"; -1 is the documented
            // value for "no response body", which is what an empty reply is.
            exchange.sendResponseHeaders(200, if (responseLength == 0L) -1L else responseLength)
            exchange.responseBody.use {
                outputTransport.b.writeTo(it)
            }
        } finally {
            // Idempotent: a no-op for streams already closed above, but guarantees the
            // connection is released if anything before this point threw.
            exchange.close()
        }
    }

    /**
     * Starts a new server on an ephemeral `localhost` port, serving every path with
     * [handleRequest] on a single worker thread. Any server left over from a previous call
     * is stopped first so repeated calls do not leak servers or threads.
     */
    fun run() {
        cleanupServer()

        val httpServer = HttpServer.create(InetSocketAddress("localhost", 0), 0)
        val workerPool = Executors.newSingleThreadExecutor()

        val context: HttpContext = httpServer.createContext("/")
        context.setHandler(::handleRequest)
        httpServer.setExecutor(workerPool)
        httpServer.start()

        server = httpServer
        serverExecutor = workerPool
    }

    /**
     * Returns the local port the running server is bound to.
     *
     * @throws IllegalStateException if the server has not been started or was already stopped.
     */
    fun port(): Int {
        val running = checkNotNull(server) { "TestServer is not running; call run() first" }
        return running.address.port
    }

    override fun beforeEach(context: ExtensionContext) {
        run()
    }

    override fun afterEach(context: ExtensionContext) {
        cleanupServer()
    }

    /** Stops the server if it is running; safe to call more than once. */
    fun close() {
        cleanupServer()
    }

    /**
     * Stops the server immediately (no grace delay) and shuts down its executor.
     * Does nothing for parts that are already stopped.
     */
    private fun cleanupServer() {
        server?.stop(0)
        server = null
        serverExecutor?.shutdown()
        serverExecutor = null
    }

    /**
     * Wraps [transport] in the Thrift protocol selected by [protocol].
     *
     * @throws AssertionError if [protocol] is a value this server does not support.
     */
    private fun protocolFactory(transport: Transport): Protocol = when (protocol) {
        ServerProtocol.BINARY -> BinaryProtocol(transport)
        ServerProtocol.COMPACT -> CompactProtocol(transport)
        ServerProtocol.JSON -> JsonProtocol(transport)
        else -> throw AssertionError("Invalid protocol value: $protocol")
    }
}