Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

app engine support #60

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
85 changes: 59 additions & 26 deletions oauth.go
Expand Up @@ -56,6 +56,9 @@ import (
"strings"
"sync"
"time"

"gopkg.in/webhelp.v1/whcompat"
"golang.org/x/net/context"
)

const (
Expand Down Expand Up @@ -196,6 +199,9 @@ type Consumer struct {
// Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary
HttpClient HttpClient

// If HttpClientFunc is set, will be used instead of HttpClient.
HttpClientFunc func(ctx context.Context) (HttpClient, error)

// Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with
// requests. (like "Accept" to specify the response type as XML or JSON) Note that this
// will only *add* headers, not set existing ones.
Expand Down Expand Up @@ -383,10 +389,14 @@ func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey,
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) {
return c.GetRequestTokenAndUrlWithParams(callbackUrl, c.AdditionalParams)
return c.GetRequestTokenAndUrlWithParamsCtx(context.TODO(), callbackUrl, c.AdditionalParams)
}

func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) {
return c.GetRequestTokenAndUrlWithParamsCtx(context.TODO(), callbackUrl, additionalParams)
}

func (c *Consumer) GetRequestTokenAndUrlWithParamsCtx(ctx context.Context, callbackUrl string, additionalParams map[string]string) (rtoken *RequestToken, loginUrl string, err error) {
params := c.baseParams(c.consumerKey, additionalParams)
if callbackUrl != "" {
params.Add(CALLBACK_PARAM, callbackUrl)
Expand All @@ -401,7 +411,7 @@ func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additiona
return nil, "", err
}

resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params)
resp, err := c.getBody(ctx, c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params)
if err != nil {
return nil, "", errors.New("getBody: " + err.Error())
}
Expand Down Expand Up @@ -440,15 +450,19 @@ func (c *Consumer) GetRequestTokenAndUrlWithParams(callbackUrl string, additiona
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) {
return c.AuthorizeTokenWithParams(rtoken, verificationCode, c.AdditionalParams)
return c.AuthorizeTokenWithParamsCtx(context.TODO(), rtoken, verificationCode, c.AdditionalParams)
}

func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) {
return c.AuthorizeTokenWithParamsCtx(context.TODO(), rtoken, verificationCode, additionalParams)
}

func (c *Consumer) AuthorizeTokenWithParamsCtx(ctx context.Context, rtoken *RequestToken, verificationCode string, additionalParams map[string]string) (atoken *AccessToken, err error) {
params := map[string]string{
VERIFIER_PARAM: verificationCode,
TOKEN_PARAM: rtoken.Token,
}
return c.makeAccessTokenRequestWithParams(params, rtoken.Secret, additionalParams)
return c.makeAccessTokenRequestWithParams(ctx, params, rtoken.Secret, additionalParams)
}

// Use the service provider to refresh the AccessToken for a given session.
Expand All @@ -472,6 +486,10 @@ func (c *Consumer) AuthorizeTokenWithParams(rtoken *RequestToken, verificationCo
// Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to
// refresh the token, or if an error occurred when making the request.
func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) {
return c.RefreshTokenCtx(context.TODO(), accessToken)
}

func (c *Consumer) RefreshTokenCtx(ctx context.Context, accessToken *AccessToken) (atoken *AccessToken, err error) {
params := make(map[string]string)
sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM]
if !ok {
Expand All @@ -480,7 +498,7 @@ func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken,
params[SESSION_HANDLE_PARAM] = sessionHandle
params[TOKEN_PARAM] = accessToken.Token

return c.makeAccessTokenRequest(params, accessToken.Secret)
return c.makeAccessTokenRequest(ctx, params, accessToken.Secret)
}

// Use the service provider to obtain an AccessToken for a given session
Expand All @@ -497,11 +515,11 @@ func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken,
//
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) {
return c.makeAccessTokenRequestWithParams(params, secret, c.AdditionalParams)
func (c *Consumer) makeAccessTokenRequest(ctx context.Context, params map[string]string, secret string) (atoken *AccessToken, err error) {
return c.makeAccessTokenRequestWithParams(ctx, params, secret, c.AdditionalParams)
}

func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) {
func (c *Consumer) makeAccessTokenRequestWithParams(ctx context.Context, params map[string]string, secret string, additionalParams map[string]string) (atoken *AccessToken, err error) {
orderedParams := c.baseParams(c.consumerKey, additionalParams)
for key, value := range params {
orderedParams.Add(key, value)
Expand All @@ -516,7 +534,7 @@ func (c *Consumer) makeAccessTokenRequestWithParams(params map[string]string, se
return nil, err
}

resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams)
resp, err := c.getBody(ctx, c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -559,7 +577,7 @@ func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) {
// - err:
// Set only if there was an error, nil otherwise.
func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token)
return c.makeAuthorizedRequest(context.TODO(), "GET", url, LOC_URL, "", userParams, token)
}

func encodeUserParams(userParams map[string]string) string {
Expand All @@ -585,40 +603,40 @@ func (c *Consumer) Post(url string, userParams map[string]string, token *AccessT
// ** DEPRECATED **
// Please call "Post" on the http client returned by MakeHttpClient instead
func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token)
return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_BODY, body, userParams, token)
}

// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request)
func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token)
return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_JSON, body, nil, token)
}

// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and set the "Content-Type" header explicitly in the http.Request)
func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token)
return c.makeAuthorizedRequest(context.TODO(), "POST", url, LOC_XML, body, nil, token)
}

// ** DEPRECATED **
// Please call "Do" on the http client returned by MakeHttpClient instead
// (and setup the multipart data explicitly in the http.Request)
func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token)
return c.makeAuthorizedRequestReader(context.TODO(), "POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token)
}

// ** DEPRECATED **
// Please call "Delete" on the http client returned by MakeHttpClient instead
func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token)
return c.makeAuthorizedRequest(context.TODO(), "DELETE", url, LOC_URL, "", userParams, token)
}

// ** DEPRECATED **
// Please call "Put" on the http client returned by MakeHttpClient instead
func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token)
return c.makeAuthorizedRequest(context.TODO(), "PUT", url, LOC_URL, body, userParams, token)
}

func (c *Consumer) Debug(enabled bool) {
Expand All @@ -642,19 +660,19 @@ func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// consumer.Post() etc), and the new API (which takes actual http.Requests)
//
// So, here we construct the appropriate HTTP request for the inputs.
func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
func (c *Consumer) makeAuthorizedRequestReader(ctx context.Context, method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
urlObject, err := url.Parse(urlString)
if err != nil {
return nil, err
}

request := &http.Request{
request := whcompat.WithContext(&http.Request{
Method: method,
URL: urlObject,
Header: http.Header{},
Body: body,
ContentLength: int64(contentLength),
}
}, ctx)

vals := url.Values{}
for k, v := range userParams {
Expand Down Expand Up @@ -926,17 +944,21 @@ func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, er
fmt.Printf("Request: %v\n", serverRequest)
}

resp, err := rt.consumer.HttpClient.Do(serverRequest)
client, err := rt.consumer.httpClient(whcompat.Context(userRequest))
if err != nil {
return nil, errors.New("httpClient: " + err.Error())
}

resp, err := client.Do(serverRequest)
if err != nil {
return resp, err
}

return resp, nil
}

func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token)
func (c *Consumer) makeAuthorizedRequest(ctx context.Context, method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) {
return c.makeAuthorizedRequestReader(ctx, method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token)
}

type request struct {
Expand Down Expand Up @@ -1215,8 +1237,8 @@ func (c *Consumer) requestString(method string, url string, params *OrderedParam
return result
}

func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) {
resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams)
func (c *Consumer) getBody(ctx context.Context, method, url string, oauthParams *OrderedParams) (*string, error) {
resp, err := c.httpExecute(ctx, method, url, "", 0, nil, oauthParams)
if err != nil {
return nil, errors.New("httpExecute: " + err.Error())
}
Expand Down Expand Up @@ -1254,7 +1276,14 @@ func (e HTTPExecuteError) Error() string {
"\tRequest Headers: " + e.RequestHeaders
}

func (c *Consumer) httpExecute(
func (c *Consumer) httpClient(ctx context.Context) (HttpClient, error) {
if c.HttpClientFunc != nil {
return c.HttpClientFunc(ctx)
}
return c.HttpClient, nil
}

func (c *Consumer) httpExecute(ctx context.Context,
method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) {
// Create base request.
req, err := http.NewRequest(method, urlStr, body)
Expand Down Expand Up @@ -1295,7 +1324,11 @@ func (c *Consumer) httpExecute(
if c.debug {
fmt.Printf("Request: %v\n", req)
}
resp, err := c.HttpClient.Do(req)
client, err := c.httpClient(ctx)
if err != nil {
return nil, errors.New("httpClient: " + err.Error())
}
resp, err := client.Do(req)
if err != nil {
return nil, errors.New("Do: " + err.Error())
}
Expand Down