oidcclient.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. // Copyright (c) 2016 VMware, Inc. All Rights Reserved.
  2. //
  3. // This product is licensed to you under the Apache License, Version 2.0 (the "License").
  4. // You may not use this product except in compliance with the License.
  5. //
  6. // This product may include a number of subcomponents with separate copyright notices and
  7. // license terms. Your use of these subcomponents is subject to the terms and conditions
  8. // of the subcomponent's license, as noted in the LICENSE file.
  9. package lightwave
  10. import (
  11. "crypto/tls"
  12. "crypto/x509"
  13. "encoding/json"
  14. "encoding/pem"
  15. "fmt"
  16. "io/ioutil"
  17. "log"
  18. "net/http"
  19. "net/url"
  20. "strings"
  21. )
  22. const tokenScope string = "openid offline_access"
  23. type OIDCClient struct {
  24. httpClient *http.Client
  25. logger *log.Logger
  26. Endpoint string
  27. Options *OIDCClientOptions
  28. }
  29. type OIDCClientOptions struct {
  30. // Whether or not to ignore any TLS errors when talking to photon,
  31. // false by default.
  32. IgnoreCertificate bool
  33. // List of root CA's to use for server validation
  34. // nil by default.
  35. RootCAs *x509.CertPool
  36. // The scope values to use when requesting tokens
  37. TokenScope string
  38. }
  39. func NewOIDCClient(endpoint string, options *OIDCClientOptions, logger *log.Logger) (c *OIDCClient) {
  40. if logger == nil {
  41. logger = log.New(ioutil.Discard, "", log.LstdFlags)
  42. }
  43. options = buildOptions(options)
  44. tr := &http.Transport{
  45. TLSClientConfig: &tls.Config{
  46. InsecureSkipVerify: options.IgnoreCertificate,
  47. RootCAs: options.RootCAs},
  48. }
  49. c = &OIDCClient{
  50. httpClient: &http.Client{Transport: tr},
  51. logger: logger,
  52. Endpoint: strings.TrimRight(endpoint, "/"),
  53. Options: options,
  54. }
  55. return
  56. }
  57. func buildOptions(options *OIDCClientOptions) (result *OIDCClientOptions) {
  58. result = &OIDCClientOptions{
  59. TokenScope: tokenScope,
  60. }
  61. if options == nil {
  62. return
  63. }
  64. result.IgnoreCertificate = options.IgnoreCertificate
  65. if options.RootCAs != nil {
  66. result.RootCAs = options.RootCAs
  67. }
  68. if options.TokenScope != "" {
  69. result.TokenScope = options.TokenScope
  70. }
  71. return
  72. }
  73. func (client *OIDCClient) buildUrl(path string) (url string) {
  74. return fmt.Sprintf("%s%s", client.Endpoint, path)
  75. }
  76. // Cert download helper
  77. const certDownloadPath string = "/afd/vecs/ssl"
  78. type lightWaveCert struct {
  79. Value string `json:"encoded"`
  80. }
  81. func (client *OIDCClient) GetRootCerts() (certList []*x509.Certificate, err error) {
  82. // turn TLS verification off for
  83. originalTr := client.httpClient.Transport
  84. defer client.setTransport(originalTr)
  85. tr := &http.Transport{
  86. TLSClientConfig: &tls.Config{
  87. InsecureSkipVerify: true,
  88. },
  89. }
  90. client.setTransport(tr)
  91. // get the certs
  92. resp, err := client.httpClient.Get(client.buildUrl(certDownloadPath))
  93. if err != nil {
  94. return
  95. }
  96. defer resp.Body.Close()
  97. if resp.StatusCode != 200 {
  98. err = fmt.Errorf("Unexpected error retrieving auth server certs: %v %s", resp.StatusCode, resp.Status)
  99. return
  100. }
  101. // parse the certs
  102. certsData := &[]lightWaveCert{}
  103. err = json.NewDecoder(resp.Body).Decode(certsData)
  104. if err != nil {
  105. return
  106. }
  107. certList = make([]*x509.Certificate, len(*certsData))
  108. for idx, cert := range *certsData {
  109. block, _ := pem.Decode([]byte(cert.Value))
  110. if block == nil {
  111. err = fmt.Errorf("Unexpected response format: %v", certsData)
  112. return nil, err
  113. }
  114. decodedCert, err := x509.ParseCertificate(block.Bytes)
  115. if err != nil {
  116. return nil, err
  117. }
  118. certList[idx] = decodedCert
  119. }
  120. return
  121. }
  122. func (client *OIDCClient) setTransport(tr http.RoundTripper) {
  123. client.httpClient.Transport = tr
  124. }
  125. // Toke request helpers
  126. const tokenPath string = "/openidconnect/token"
  127. const passwordGrantFormatString = "grant_type=password&username=%s&password=%s&scope=%s"
  128. const refreshTokenGrantFormatString = "grant_type=refresh_token&refresh_token=%s"
  129. type OIDCTokenResponse struct {
  130. AccessToken string `json:"access_token"`
  131. ExpiresIn int `json:"expires_in"`
  132. RefreshToken string `json:"refresh_token,omitempty"`
  133. IdToken string `json:"id_token"`
  134. TokenType string `json:"token_type"`
  135. }
  136. func (client *OIDCClient) GetTokenByPasswordGrant(username string, password string) (tokens *OIDCTokenResponse, err error) {
  137. username = url.QueryEscape(username)
  138. password = url.QueryEscape(password)
  139. body := fmt.Sprintf(passwordGrantFormatString, username, password, client.Options.TokenScope)
  140. return client.getToken(body)
  141. }
  142. func (client *OIDCClient) GetTokenByRefreshTokenGrant(refreshToken string) (tokens *OIDCTokenResponse, err error) {
  143. body := fmt.Sprintf(refreshTokenGrantFormatString, refreshToken)
  144. return client.getToken(body)
  145. }
  146. func (client *OIDCClient) getToken(body string) (tokens *OIDCTokenResponse, err error) {
  147. request, err := http.NewRequest("POST", client.buildUrl(tokenPath), strings.NewReader(body))
  148. if err != nil {
  149. return nil, err
  150. }
  151. request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
  152. resp, err := client.httpClient.Do(request)
  153. if err != nil {
  154. return nil, err
  155. }
  156. defer resp.Body.Close()
  157. err = client.checkResponse(resp)
  158. if err != nil {
  159. return nil, err
  160. }
  161. tokens = &OIDCTokenResponse{}
  162. err = json.NewDecoder(resp.Body).Decode(tokens)
  163. if err != nil {
  164. return nil, err
  165. }
  166. return
  167. }
  168. type OIDCError struct {
  169. Code string `json:"error"`
  170. Message string `json:"error_description"`
  171. }
  172. func (e OIDCError) Error() string {
  173. return fmt.Sprintf("%v: %v", e.Code, e.Message)
  174. }
  175. func (client *OIDCClient) checkResponse(response *http.Response) (err error) {
  176. if response.StatusCode/100 == 2 {
  177. return
  178. }
  179. respBody, readErr := ioutil.ReadAll(response.Body)
  180. if err != nil {
  181. return fmt.Errorf(
  182. "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
  183. }
  184. var oidcErr OIDCError
  185. err = json.Unmarshal(respBody, &oidcErr)
  186. if err != nil {
  187. return fmt.Errorf(
  188. "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
  189. }
  190. return oidcErr
  191. }