/
authority.go
137 lines (116 loc) · 3.09 KB
/
authority.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package adal
import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"golang.org/x/xerrors"
)
// instanceDiscoveryEndpoint is the instance discovery service that
// (*Authority).Validate queries to confirm an authority whose host is not in
// the well-known host list.
const instanceDiscoveryEndpoint = "https://login.windows.net/common/discovery/instance"
// Authority describes an identity authority addressed by a URL of the form
// https://<host>/<tenant>[/...]. NewAuthority builds one from such a URL.
type Authority struct {
	// URL is the authority URL as parsed by NewAuthority.
	URL *url.URL
	// Host is the host component of URL (net/url keeps any ":port" here).
	Host string
	// Tenant is the first path segment of URL.
	Tenant string
	// validated reports whether the authority is trusted. NewAuthority
	// initialises it to the negation of its validateAuthority argument, and
	// a successful Validate sets it to true.
	validated bool
}
// NewAuthority parses urlStr as an authority URL of the form
// https://<host>/<tenant>[/...] and returns the corresponding Authority.
//
// urlStr must be an absolute https URL without a query string, and its path
// must start with a non-empty tenant segment; otherwise an error is returned.
//
// When validateAuthority is false the authority is created already marked as
// validated, so Validate becomes a no-op. When it is true, Validated reports
// false until Validate succeeds.
func NewAuthority(urlStr string, validateAuthority bool) (*Authority, error) {
	u, parseErr := url.ParseRequestURI(urlStr)
	if parseErr != nil {
		return nil, xerrors.Errorf("parse url: %w", parseErr)
	}

	if invalid := validateAuthorityURL(u); invalid != nil {
		return nil, xerrors.Errorf("validate(url=%s): %w", u.String(), invalid)
	}

	host, tenant, splitErr := parseAuthority(u)
	if splitErr != nil {
		return nil, xerrors.Errorf("parse authority(url=%s): %w", u.String(), splitErr)
	}

	authority := &Authority{
		URL:    u,
		Host:   host,
		Tenant: tenant,
		// Skipping validation is expressed as "already validated".
		validated: !validateAuthority,
	}
	return authority, nil
}
// baseURL returns "https://<Host>/<Tenant>". AuthorityURL, TokenURL and
// DeviceURL append their OAuth2 paths to this prefix.
func (a *Authority) baseURL() string {
	return "https://" + a.Host + "/" + a.Tenant
}
// AuthorityURL returns the OAuth2 authorization endpoint of the authority,
// i.e. https://<host>/<tenant>/oauth2/authorize.
func (a *Authority) AuthorityURL() string {
	const authorizePath = "/oauth2/authorize"
	return a.baseURL() + authorizePath
}
// TokenURL returns the OAuth2 token endpoint of the authority,
// i.e. https://<host>/<tenant>/oauth2/token.
func (a *Authority) TokenURL() string {
	const tokenPath = "/oauth2/token"
	return a.baseURL() + tokenPath
}
// DeviceURL returns the OAuth2 device-code endpoint of the authority,
// i.e. https://<host>/<tenant>/oauth2/devicecode.
func (a *Authority) DeviceURL() string {
	const deviceCodePath = "/oauth2/devicecode"
	return a.baseURL() + deviceCodePath
}
// Validate checks that the authority is trusted. It returns nil immediately if
// the authority is already validated, which includes authorities created with
// validation disabled.
//
// A host listed in wellKnownAuthorityHosts is accepted without a network call.
// The comparison ignores case, because host names are case-insensitive and
// net/url does not normalise their case.
//
// For any other host, a GET request is sent to instanceDiscoveryEndpoint with
// this authority's authorization endpoint as a query parameter. The authority
// is accepted only when the response status is 2xx and the JSON body carries a
// non-empty tenant_discovery_endpoint. On acceptance the authority is marked
// validated; otherwise it is left unvalidated and an error is returned.
//
// A nil httpClient falls back to http.DefaultClient.
// NOTE(review): no context or timeout is applied here, so callers should pass
// a client with Timeout set.
func (a *Authority) Validate(httpClient *http.Client) error {
	if a.validated {
		return nil
	}
	for _, authorityHost := range wellKnownAuthorityHosts {
		if strings.EqualFold(a.URL.Host, authorityHost) {
			a.validated = true
			return nil
		}
	}
	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	u, err := url.ParseRequestURI(instanceDiscoveryEndpoint)
	if err != nil {
		// instanceDiscoveryEndpoint is a constant, so this is a programming
		// error; report it instead of panicking inside library code.
		return xerrors.Errorf("parse instance discovery endpoint: %w", err)
	}
	query := url.Values{}
	query.Add("authorization_endpoint", a.AuthorityURL())
	query.Add("api-version", "1.0")
	u.RawQuery = query.Encode()

	resp, err := httpClient.Get(u.String())
	if err != nil {
		return xerrors.Errorf("instance discovery request: %w", err)
	}
	defer func() {
		// Drain any unread body on every path, including non-2xx responses,
		// so the transport can reuse the underlying connection.
		_, _ = io.Copy(ioutil.Discard, resp.Body)
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || 300 <= resp.StatusCode {
		return xerrors.Errorf("instance discovery request(expected=2xx, actual=%s)", resp.Status)
	}

	var out struct {
		TenantDiscoveryEndpoint string `json:"tenant_discovery_endpoint"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return xerrors.Errorf("parse instance discovery response: %w", err)
	}
	if len(out.TenantDiscoveryEndpoint) == 0 {
		return xerrors.New("`tenant_discovery_endpoint` was not found")
	}
	a.validated = true
	return nil
}
// IsADFSAuthority reports whether the authority's tenant segment equals
// "adfs" once lower-cased (so "ADFS" and "AdFs" also match).
func (a *Authority) IsADFSAuthority() bool {
	const adfsTenant = "adfs"
	lowered := strings.ToLower(a.Tenant)
	return lowered == adfsTenant
}
// Validated reports whether the authority is currently trusted. It is true
// when the authority was created with validation disabled, or once a call to
// Validate has succeeded.
func (a *Authority) Validated() bool {
	return a.validated
}
// validateAuthorityURL checks the structural requirements for an authority
// URL: it must use the https scheme, name a host, and carry no query string.
// It returns nil when all requirements hold, otherwise an error describing the
// first one violated.
func validateAuthorityURL(aURL *url.URL) error {
	if aURL.Scheme != "https" {
		return xerrors.New("the authority url must be an https endpoint")
	}
	// A URL such as "https:///tenant" parses with an empty Host. Accepting it
	// would yield endpoints like "https:///tenant/oauth2/token".
	if aURL.Host == "" {
		return xerrors.New("the authority url must have a host")
	}
	if len(aURL.RawQuery) != 0 {
		return xerrors.New("the authority url must not have a query string")
	}
	return nil
}
// parseAuthority returns the host and tenant of an authority URL. The tenant
// is the text after the first '/' of the path, up to the next '/'. For a
// rooted path such as "/tenant/...", that is the first path segment. An error
// is returned when that text is missing or empty.
func parseAuthority(aURL *url.URL) (string, string, error) {
	// SplitN with n=3 yields the same second element as a full Split, while
	// leaving the rest of the path unsplit.
	segments := strings.SplitN(aURL.Path, "/", 3)
	if len(segments) < 2 || segments[1] == "" {
		return "", "", xerrors.New("could not determine tenant")
	}
	return aURL.Host, segments[1], nil
}