diff --git a/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt b/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt index d35c3ed33e..63de102d4d 100644 --- a/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt +++ b/src/main/kotlin/cc/unitmesh/devti/llms/openai/OpenAIProvider.kt @@ -8,6 +8,7 @@ import cc.unitmesh.devti.coder.recording.Recording import cc.unitmesh.devti.coder.recording.RecordingInstruction import cc.unitmesh.devti.settings.AutoDevSettingsState import cc.unitmesh.devti.settings.coder.coderSetting +import cc.unitmesh.devti.settings.SELECT_CUSTOM_MODEL import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.diagnostic.Logger @@ -63,7 +64,13 @@ class OpenAIProvider(val project: Project) : LLMProvider { private val timeout = Duration.ofSeconds(600) private val openAiVersion: String - get() = AutoDevSettingsState.getInstance().openAiModel + get() { + val customModel = AutoDevSettingsState.getInstance().customModel + if(AutoDevSettingsState.getInstance().openAiModel == SELECT_CUSTOM_MODEL) { + AutoDevSettingsState.getInstance().openAiModel = customModel + } + return AutoDevSettingsState.getInstance().openAiModel + } private val openAiKey: String get() = AutoDevSettingsState.getInstance().openAiKey diff --git a/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt b/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt index badb23925e..086169f86d 100644 --- a/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt +++ b/src/main/kotlin/cc/unitmesh/devti/settings/AutoDevSettingsState.kt @@ -21,6 +21,7 @@ class AutoDevSettingsState : PersistentStateComponent { var customEngineServer = "" var customEngineToken = "" var customPrompts = "" + var customModel = "" // 星火有三个版本 https://console.xfyun.cn/services/bm3 var xingHuoApiVersion = XingHuoApiVersion.V3 diff --git a/src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt b/src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt index 01710ec47e..b6c9b1d494 100644 --- a/src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt +++ b/src/main/kotlin/cc/unitmesh/devti/settings/Constants.kt @@ -1,6 +1,6 @@ package cc.unitmesh.devti.settings -val OPENAI_MODEL = arrayOf("gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4") +val OPENAI_MODEL = arrayOf("gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "custom") val AI_ENGINES = arrayOf("OpenAI", "Custom", "Azure", "XingHuo") enum class AIEngines { @@ -35,3 +35,4 @@ val DEFAULT_AI_MODEL = OPENAI_MODEL[0] val HUMAN_LANGUAGES = arrayOf("English", "中文") val DEFAULT_HUMAN_LANGUAGE = HUMAN_LANGUAGES[0] val MAX_TOKEN_LENGTH = 4000 +val SELECT_CUSTOM_MODEL = "custom" diff --git a/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt b/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt index b6b43bc98f..4375184b2b 100644 --- a/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt +++ b/src/main/kotlin/cc/unitmesh/devti/settings/LLMSettingComponent.kt @@ -23,6 +23,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { private val maxTokenLengthParam by LLMParam.creating { Editable(settings.maxTokenLength) } private val openAIModelsParam by LLMParam.creating { ComboBox(settings.openAiModel, OPENAI_MODEL.toList()) } private val openAIKeyParam by LLMParam.creating { Password(settings.openAiKey) } + private val customModelParam: LLMParam by LLMParam.creating { Editable(settings.customModel) } private val customOpenAIHostParam: LLMParam by LLMParam.creating { Editable(settings.customOpenAiHost) } private val gitTypeParam: LLMParam by LLMParam.creating { ComboBox(settings.gitType, GIT_TYPE.toList()) } @@ -78,6 +79,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { AIEngines.OpenAI to listOf( openAIModelsParam, openAIKeyParam, + customModelParam, customOpenAIHostParam, ), AIEngines.Custom to listOf( @@ -89,10 +91,10 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { ), AIEngines.XingHuo to listOf( xingHuoApiVersionParam, - xingHuoAppIDParam, - xingHuoApiKeyParam, - xingHuoApiSecretParam, - ), + xingHuoAppIDParam, + xingHuoApiKeyParam, + xingHuoApiSecretParam, + ), ) @@ -186,6 +188,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { gitLabTokenParam.value = gitlabToken gitLabUrlParam.value = gitlabUrl openAIKeyParam.value = openAiKey + customModelParam.value = customModel customOpenAIHostParam.value = customOpenAiHost customEngineServerParam.value = customEngineServer customEngineResponseTypeParam.value = customEngineResponseType @@ -212,6 +215,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { gitlabUrl = gitLabUrlParam.value gitlabToken = gitLabTokenParam.value openAiKey = openAIKeyParam.value + customModel = customModelParam.value customOpenAiHost = customOpenAIHostParam.value xingHuoApiSecrect = xingHuoApiSecretParam.value xingHuoApiVersion = XingHuoApiVersion.of(xingHuoApiVersionParam.value) @@ -237,6 +241,7 @@ class LLMSettingComponent(private val settings: AutoDevSettingsState) { settings.gitlabUrl != gitLabUrlParam.value || settings.gitlabToken != gitLabTokenParam.value || settings.openAiKey != openAIKeyParam.value || + settings.customModel != customModelParam.value || settings.xingHuoApiSecrect != xingHuoApiSecretParam.value || settings.xingHuoApiVersion != XingHuoApiVersion.of(xingHuoApiVersionParam.value) || settings.xingHuoAppId != xingHuoAppIDParam.value || diff --git a/src/main/resources/messages/AutoDevBundle.properties b/src/main/resources/messages/AutoDevBundle.properties index 3132dbc615..ac1fa11de6 100644 --- a/src/main/resources/messages/AutoDevBundle.properties +++ b/src/main/resources/messages/AutoDevBundle.properties @@ -56,6 +56,7 @@ settings.xingHuoApiKeyParam=XingHuo API Key settings.xingHuoApiSecretParam=XingHuo API Secret settings.xingHuoAppIDParam=XingHuo App ID settings.xingHuoApiVersionParam=XingHuo API Version +settings.customModelParam= Custom Model settings.delaySecondsParam=Quest Delay Seconds settings.customEngineResponseFormatParam=Custom Response Format (Json Path)