Skip to content

Commit

Permalink
Merge pull request #7 from umatter/function_calls
Browse files Browse the repository at this point in the history
Update chat_completion to allow function calls
  • Loading branch information
umatter committed Aug 14, 2023
2 parents acfe924 + e3bdb46 commit b7c6545
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions R/chat_completion.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#' based on the conversation history and the specified model parameters.
#'
#' @param msgs A data.frame containing the chat history to generate text from or
#' a chatlog object.
#' a chatlog object
#' @param functions An optional list of functions to use for the function call.
#' @param function_call An optional list specifying the function call to use.
#' @param model A character string specifying the ID of the model to use.
#' The default value is "gpt-3.5-turbo".
#' @param temperature An optional numeric scalar specifying the sampling
Expand Down Expand Up @@ -56,10 +58,11 @@
#' chat_completion(msgs_df)
#' }
#' @export
chat_completion <- function(msgs, model = "gpt-3.5-turbo", temperature = NULL,
max_tokens = NULL, n = NULL, stop = NULL,presence_penalty = NULL,
frequency_penalty = NULL, best_of = NULL, logit_bias = NULL, stream = FALSE,
top_p = NULL, user = NULL) {
chat_completion <- function(msgs, functions = NULL, function_call = NULL,
model = "gpt-3.5-turbo", temperature = NULL, max_tokens = NULL, n = NULL,
stop = NULL, presence_penalty = NULL, frequency_penalty = NULL,
best_of = NULL, logit_bias = NULL, stream = FALSE, top_p = NULL,
user = NULL) {
# the relevant API endpoint
API_ENDPOINT <- "https://api.openai.com/v1/chat/completions"

Expand All @@ -86,6 +89,21 @@ chat_completion <- function(msgs, model = "gpt-3.5-turbo", temperature = NULL,
messages = msgs
)

if (!is.null(functions)) {
# Currently, only two models support the 'functions' argument
# TODO: stop or warn ?
supported_models <- c("gpt-3.5-turbo-0613", "gpt-4-0613")
if (!any(model %in% supported_models)) {
stop("The 'functions' argument is currently only supported for the ",
paste(supported_models, collapse = ", "), " models.")
}
payload$functions <- functions

if (!is.null(function_call)) {
payload$function_call <- function_call
}
}

if (!is.null(max_tokens)) {
payload$max_tokens <- max_tokens
}
Expand Down

0 comments on commit b7c6545

Please sign in to comment.