OpenAIProvider.kt

package cc.unitmesh.devti.llms.openai

import cc.unitmesh.devti.gui.chat.ChatRole
import cc.unitmesh.devti.llms.LLMProvider
import cc.unitmesh.devti.coder.recording.EmptyRecording
import cc.unitmesh.devti.coder.recording.JsonlRecording
import cc.unitmesh.devti.coder.recording.Recording
import cc.unitmesh.devti.coder.recording.RecordingInstruction
import cc.unitmesh.devti.settings.AutoDevSettingsState
import cc.unitmesh.devti.settings.coder.coderSetting
import cc.unitmesh.devti.settings.SELECT_CUSTOM_MODEL
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.theokanning.openai.client.OpenAiApi
import com.theokanning.openai.completion.chat.ChatCompletionRequest
import com.theokanning.openai.completion.chat.ChatMessage
import com.theokanning.openai.completion.chat.ChatMessageRole
import com.theokanning.openai.service.OpenAiService
import com.theokanning.openai.service.OpenAiService.defaultClient
import com.theokanning.openai.service.OpenAiService.defaultObjectMapper
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import retrofit2.Retrofit
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory
import retrofit2.converter.jackson.JacksonConverterFactory
import java.time.Duration
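
/**
 * Project-level OpenAI provider backed by the openai-java (com.theokanning) client.
 * Requests go to the official endpoint by default, or to the custom host configured in
 * AutoDevSettingsState; the in-memory chat history is dropped once its accumulated
 * length exceeds the configured max token length.
 */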
@Service(Service.Level.PROJECT)
class OpenAIProvider(val project: Project) : LLMProvider {
    private val service: OpenAiService
        get() {
            if (openAiKey.isEmpty()) {
                logger.error("openAiKey is empty")
                throw IllegalStateException("openAiKey is empty")
            }

            var openAiProxy = AutoDevSettingsState.getInstance().customOpenAiHost
            return if (openAiProxy.isEmpty()) {
                // No custom host configured: use the default OpenAI endpoint.
                OpenAiService(openAiKey, timeout)
            } else {
                // Route requests through the configured custom host; Retrofit requires
                // the base URL to end with a trailing slash.
                if (!openAiProxy.endsWith("/")) {
                    openAiProxy += "/"
                }

                val mapper = defaultObjectMapper()
                val client = defaultClient(openAiKey, timeout)
                val retrofit = Retrofit.Builder()
                    .baseUrl(openAiProxy)
                    .client(client)
                    .addConverterFactory(JacksonConverterFactory.create(mapper))
                    .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                    .build()

                val api = retrofit.create(OpenAiApi::class.java)
                OpenAiService(api)
            }
        }

    private val timeout = Duration.ofSeconds(600)

    private val openAiVersion: String
        get() {
            // When the "custom model" option is selected, resolve it to the user-supplied model name.
            val customModel = AutoDevSettingsState.getInstance().customModel
            if (AutoDevSettingsState.getInstance().openAiModel == SELECT_CUSTOM_MODEL) {
                AutoDevSettingsState.getInstance().openAiModel = customModel
            }
            return AutoDevSettingsState.getInstance().openAiModel
        }

    private val openAiKey: String
        get() = AutoDevSettingsState.getInstance().openAiKey

    private val maxTokenLength: Int
        get() = AutoDevSettingsState.getInstance().fetchMaxTokenLength()

    private val messages: MutableList<ChatMessage> = ArrayList()
    private var historyMessageLength: Int = 0

    private val recording: Recording
        get() {
            if (project.coderSetting.state.recordingInLocal) {
                return project.service<JsonlRecording>()
            }
            return EmptyRecording()
        }

    override fun clearMessage() {
        messages.clear()
        historyMessageLength = 0
    }

    override fun appendLocalMessage(msg: String, role: ChatRole) {
        val message = ChatMessage(role.roleName(), msg)
        messages.add(message)
    }

    override fun prompt(promptText: String): String {
        val completionRequest = prepareRequest(promptText, "")
        val completion = service.createChatCompletion(completionRequest)
        return completion.choices[0].message.content
    }

    @OptIn(ExperimentalCoroutinesApi::class)
    override fun stream(promptText: String, systemPrompt: String, keepHistory: Boolean): Flow<String> {
        if (!keepHistory) {
            clearMessage()
        }
        if (project.coderSetting.state.noChatHistory) {
            messages.clear()
        }

        var output = ""
        val completionRequest = prepareRequest(promptText, systemPrompt)

        return callbackFlow {
            withContext(Dispatchers.IO) {
                // Consume the streaming response on the IO dispatcher, forwarding each
                // content delta to the flow as it arrives.
                service.streamChatCompletion(completionRequest)
                    .doOnError { error ->
                        logger.error("Error in stream", error)
                        trySend(error.message ?: "Error occurs")
                    }
                    .blockingForEach { response ->
                        if (response.choices.isNotEmpty()) {
                            val completion = response.choices[0].message
                            if (completion != null && completion.content != null) {
                                output += completion.content
                                trySend(completion.content)
                            }
                        }
                    }

                recording.write(RecordingInstruction(promptText, output))
                close()
            }
        }
    }

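    /**
     * Builds the ChatCompletionRequest: seeds the history with the system prompt on the
     * first call, appends the user message, and resets the history once its accumulated
     * length exceeds maxTokenLength.
     */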
    private fun prepareRequest(promptText: String, systemPrompt: String): ChatCompletionRequest? {
        if (messages.isEmpty()) {
            val systemMessage = ChatMessage(ChatMessageRole.SYSTEM.value(), systemPrompt)
            messages.add(systemMessage)
        }

        val userMessage = ChatMessage(ChatMessageRole.USER.value(), promptText)

        historyMessageLength += promptText.length
        if (historyMessageLength > maxTokenLength) {
            messages.clear()
        }

        messages.add(userMessage)
        logger.info("messages length: ${messages.size}")

        return ChatCompletionRequest.builder()
            .model(openAiVersion)
            .temperature(0.0)
            .messages(messages)
            .build()
    }

    companion object {
        private val logger: Logger = logger<OpenAIProvider>()
    }
}