-
Notifications
You must be signed in to change notification settings - Fork 694
/
ConsoleChatGPT.java
162 lines (124 loc) · 4.79 KB
/
ConsoleChatGPT.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package com.plexpt.chatgpt;
import com.plexpt.chatgpt.entity.chat.Message;
import com.plexpt.chatgpt.listener.ConsoleStreamListener;
import com.plexpt.chatgpt.util.Proxys;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import cn.hutool.core.util.NumberUtil;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
/**
 * Interactive command-line chat client: reads an API key and optional proxy settings from the
 * console, then repeatedly reads a prompt and streams the reply through a
 * {@link ConsoleStreamListener}.
 *
 * <p>Every console input is a block of lines terminated by an empty line (press Enter twice).
 *
 * @author plexpt
 */
@Slf4j
public class ConsoleChatGPT {

    /** Proxy used for API traffic; replaced by the interactive proxy setup when the user opts in. */
    public static Proxy proxy = Proxy.NO_PROXY;

    /**
     * Single reader over stdin shared by every prompt. A BufferedReader reads ahead, so creating a
     * new one per prompt would drop input already buffered by the previous one (e.g. pasted or
     * piped text).
     */
    private static final BufferedReader STDIN =
            new BufferedReader(new InputStreamReader(System.in));

    /**
     * Entry point: prints the banner, reads the API key and proxy settings, then runs the chat loop
     * until stdin is closed or the thread is interrupted.
     *
     * @param args ignored
     * @throws IllegalArgumentException if the API key is empty
     * @throws NumberFormatException if a proxy is requested and the port is not an integer
     */
    public static void main(String[] args) {
        System.out.println("ChatGPT - Java command-line interface");
        System.out.println("Press enter twice to submit your question.");
        System.out.println();
        System.out.println("按两次回车以提交您的问题!!!");
        System.out.println("按两次回车以提交您的问题!!!");
        System.out.println("按两次回车以提交您的问题!!!");
        System.out.println();
        System.out.println("Please enter APIKEY, press Enter twice to submit:");
        // Trim so whitespace around a pasted key is not sent as part of the key.
        String key = getInput("请输入APIKEY,按两次回车以提交:\n").trim();
        check(key);

        configureProxy();

        // Balance check is disabled; uncomment to re-enable it via getBalance(key).
        // System.out.println("Inquiry balance...");
        // System.out.println("查询余额中...");
        // BigDecimal balance = getBalance(key);
        // System.out.println("API KEY balance: " + balance.toPlainString());
        //
        // if (!NumberUtil.isGreater(balance, BigDecimal.ZERO)) {
        //     System.out.println("API KEY 余额不足: ");
        //     return;
        // }

        // Build the client once and reuse it for every prompt (the key and proxy do not change
        // inside the loop). NOTE(review): assumes ChatGPTStream can serve repeated requests.
        ChatGPTStream chatGPT = ChatGPTStream.builder()
                .apiKey(key)
                .proxy(proxy)
                .build()
                .init();

        while (true) {
            String prompt = readBlock("\nYou:\n");
            if (prompt == null) {
                // stdin is closed: stop rather than sending empty prompts in an endless loop.
                return;
            }
            if (prompt.trim().isEmpty()) {
                // Nothing typed before the terminating blank line; do not call the API.
                continue;
            }
            System.out.println("AI: ");
            if (!streamReply(chatGPT, prompt)) {
                return;
            }
        }
    }

    /**
     * Asks whether to use a proxy and, if the answer is "y" (case-insensitive), reads the proxy
     * type, host and port and stores the result in {@link #proxy}. Any type other than "http"
     * (case-insensitive) selects SOCKS5.
     *
     * @throws NumberFormatException if the entered port is not an integer
     */
    private static void configureProxy() {
        // Users in mainland China need a proxy to reach the API.
        System.out.println("是否使用代理?(y/n): ");
        System.out.println("use proxy?(y/n): ");
        String useProxy = getInput("按两次回车以提交:\n");
        if (!"y".equalsIgnoreCase(useProxy.trim())) {
            return;
        }
        System.out.println("请输入代理类型(http/socks): ");
        String type = getInput("按两次回车以提交:\n").trim();
        System.out.println("请输入代理IP: ");
        String proxyHost = getInput("按两次回车以提交:\n").trim();
        System.out.println("请输入代理端口: ");
        int proxyPort = Integer.parseInt(getInput("按两次回车以提交:\n").trim());
        if ("http".equalsIgnoreCase(type)) {
            proxy = Proxys.http(proxyHost, proxyPort);
        } else {
            proxy = Proxys.socks5(proxyHost, proxyPort);
        }
    }

    /**
     * Sends {@code prompt} as a single message and blocks until the listener reports either
     * completion or an error.
     *
     * @param chatGPT initialised streaming client
     * @param prompt  user message to send
     * @return {@code true} once the reply has finished; {@code false} if the waiting thread was
     *     interrupted (the interrupt flag is restored)
     */
    private static boolean streamReply(ChatGPTStream chatGPT, String prompt) {
        CountDownLatch done = new CountDownLatch(1);
        ConsoleStreamListener listener = new ConsoleStreamListener() {
            @Override
            public void onError(Throwable throwable, String response) {
                // Always release the waiting thread, even if reporting the error fails;
                // otherwise main() would block forever on await().
                try {
                    // The throwable may be null (e.g. an HTTP error status with no exception),
                    // so the response text can be the only diagnostic available.
                    if (throwable != null) {
                        throwable.printStackTrace();
                    }
                    if (response != null && !response.isEmpty()) {
                        System.err.println(response);
                    }
                } finally {
                    done.countDown();
                }
            }
        };
        listener.setOnComplate(msg -> done.countDown());

        chatGPT.streamChatCompletion(Arrays.asList(Message.of(prompt)), listener);
        try {
            done.await();
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Queries the balance for {@code key} using a non-streaming {@link ChatGPT} client built with
     * the current {@link #proxy}. Only referenced from the disabled balance check in
     * {@link #main}.
     *
     * @param key API key to query
     * @return balance reported by {@code ChatGPT.balance()}
     */
    private static BigDecimal getBalance(String key) {
        ChatGPT chatGPT = ChatGPT.builder()
                .apiKey(key)
                .proxy(proxy)
                .build()
                .init();
        return chatGPT.balance();
    }

    /**
     * Rejects a missing or blank API key.
     *
     * @param key the key entered by the user
     * @throws IllegalArgumentException if {@code key} is null, empty or whitespace-only
     */
    private static void check(String key) {
        if (key == null || key.trim().isEmpty()) {
            throw new IllegalArgumentException("请输入正确的KEY");
        }
    }

    /**
     * Prints {@code prompt} and reads lines from stdin until an empty line or end of input.
     *
     * @param prompt text printed (without a trailing newline) before reading
     * @return the lines read, joined with {@code "\n"}; never {@code null} (an empty string when
     *     nothing was entered or input has ended)
     */
    public static String getInput(String prompt) {
        String text = readBlock(prompt);
        return text == null ? "" : text;
    }

    /**
     * Like {@link #getInput(String)}, but distinguishes end of input from an empty entry.
     *
     * @param prompt text printed before reading
     * @return the lines read joined with {@code "\n"} (possibly empty if the first line is blank),
     *     or {@code null} if stdin ended (or failed) before any line was read
     */
    private static String readBlock(String prompt) {
        System.out.print(prompt);
        List<String> lines = new ArrayList<>();
        String line = null;
        try {
            while ((line = STDIN.readLine()) != null && !line.isEmpty()) {
                lines.add(line);
            }
        } catch (IOException e) {
            // Treat a read failure like end of input so callers stop prompting.
            System.err.println("Failed to read console input: " + e);
            line = null;
        }
        if (line == null && lines.isEmpty()) {
            return null;
        }
        return String.join("\n", lines);
    }
}