123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- // Copyright (c) 2016 VMware, Inc. All Rights Reserved.
- //
- // This product is licensed to you under the Apache License, Version 2.0 (the "License").
- // You may not use this product except in compliance with the License.
- //
- // This product may include a number of subcomponents with separate copyright notices and
- // license terms. Your use of these subcomponents is subject to the terms and conditions
- // of the subcomponent's license, as noted in the LICENSE file.
- package lightwave
- import (
- "crypto/tls"
- "crypto/x509"
- "encoding/json"
- "encoding/pem"
- "fmt"
- "io/ioutil"
- "log"
- "net/http"
- "net/url"
- "strings"
- )
// tokenScope is the scope value applied by buildOptions when the caller does
// not provide its own OIDCClientOptions.TokenScope.
const tokenScope string = "openid offline_access"
- type OIDCClient struct {
- httpClient *http.Client
- logger *log.Logger
- Endpoint string
- Options *OIDCClientOptions
- }
// OIDCClientOptions configures an OIDCClient. A nil *OIDCClientOptions is
// accepted by NewOIDCClient and yields the defaults.
type OIDCClientOptions struct {
	// IgnoreCertificate disables TLS verification of the server when true.
	// It is false by default.
	IgnoreCertificate bool

	// RootCAs is the set of root CAs used to validate the server.
	// It is nil by default.
	RootCAs *x509.CertPool

	// TokenScope is the scope value sent when requesting tokens.
	TokenScope string
}
- func NewOIDCClient(endpoint string, options *OIDCClientOptions, logger *log.Logger) (c *OIDCClient) {
- if logger == nil {
- logger = log.New(ioutil.Discard, "", log.LstdFlags)
- }
- options = buildOptions(options)
- tr := &http.Transport{
- TLSClientConfig: &tls.Config{
- InsecureSkipVerify: options.IgnoreCertificate,
- RootCAs: options.RootCAs},
- }
- c = &OIDCClient{
- httpClient: &http.Client{Transport: tr},
- logger: logger,
- Endpoint: strings.TrimRight(endpoint, "/"),
- Options: options,
- }
- return
- }
- func buildOptions(options *OIDCClientOptions) (result *OIDCClientOptions) {
- result = &OIDCClientOptions{
- TokenScope: tokenScope,
- }
- if options == nil {
- return
- }
- result.IgnoreCertificate = options.IgnoreCertificate
- if options.RootCAs != nil {
- result.RootCAs = options.RootCAs
- }
- if options.TokenScope != "" {
- result.TokenScope = options.TokenScope
- }
- return
- }
- func (client *OIDCClient) buildUrl(path string) (url string) {
- return fmt.Sprintf("%s%s", client.Endpoint, path)
- }
- // Cert download helper
// certDownloadPath is the path, relative to the client's Endpoint, that
// GetRootCerts requests the server certificates from.
const certDownloadPath string = "/afd/vecs/ssl"
// lightWaveCert is one element of the JSON array returned by the certificate
// download endpoint.
type lightWaveCert struct {
	// Value is the content of the "encoded" JSON field; GetRootCerts decodes
	// it as PEM.
	Value string `json:"encoded"`
}
- func (client *OIDCClient) GetRootCerts() (certList []*x509.Certificate, err error) {
- // turn TLS verification off for
- originalTr := client.httpClient.Transport
- defer client.setTransport(originalTr)
- tr := &http.Transport{
- TLSClientConfig: &tls.Config{
- InsecureSkipVerify: true,
- },
- }
- client.setTransport(tr)
- // get the certs
- resp, err := client.httpClient.Get(client.buildUrl(certDownloadPath))
- if err != nil {
- return
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 {
- err = fmt.Errorf("Unexpected error retrieving auth server certs: %v %s", resp.StatusCode, resp.Status)
- return
- }
- // parse the certs
- certsData := &[]lightWaveCert{}
- err = json.NewDecoder(resp.Body).Decode(certsData)
- if err != nil {
- return
- }
- certList = make([]*x509.Certificate, len(*certsData))
- for idx, cert := range *certsData {
- block, _ := pem.Decode([]byte(cert.Value))
- if block == nil {
- err = fmt.Errorf("Unexpected response format: %v", certsData)
- return nil, err
- }
- decodedCert, err := x509.ParseCertificate(block.Bytes)
- if err != nil {
- return nil, err
- }
- certList[idx] = decodedCert
- }
- return
- }
- func (client *OIDCClient) setTransport(tr http.RoundTripper) {
- client.httpClient.Transport = tr
- }
// Token request helpers
const (
	// tokenPath is the token endpoint path, relative to the client's Endpoint.
	tokenPath string = "/openidconnect/token"

	// passwordGrantFormatString is the form body template for a password
	// grant; its verbs are username, password and scope, in that order.
	passwordGrantFormatString = "grant_type=password&username=%s&password=%s&scope=%s"

	// refreshTokenGrantFormatString is the form body template for a refresh
	// token grant; its single verb is the refresh token.
	refreshTokenGrantFormatString = "grant_type=refresh_token&refresh_token=%s"
)
// OIDCTokenResponse is the JSON document decoded from a successful token
// endpoint response.
type OIDCTokenResponse struct {
	// AccessToken is the value of the "access_token" field.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the value of the "expires_in" field as reported by the
	// server.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the value of the "refresh_token" field; it is omitted
	// from encoded output when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
	// IdToken is the value of the "id_token" field.
	IdToken string `json:"id_token"`
	// TokenType is the value of the "token_type" field.
	TokenType string `json:"token_type"`
}
- func (client *OIDCClient) GetTokenByPasswordGrant(username string, password string) (tokens *OIDCTokenResponse, err error) {
- username = url.QueryEscape(username)
- password = url.QueryEscape(password)
- body := fmt.Sprintf(passwordGrantFormatString, username, password, client.Options.TokenScope)
- return client.getToken(body)
- }
- func (client *OIDCClient) GetTokenByRefreshTokenGrant(refreshToken string) (tokens *OIDCTokenResponse, err error) {
- body := fmt.Sprintf(refreshTokenGrantFormatString, refreshToken)
- return client.getToken(body)
- }
- func (client *OIDCClient) getToken(body string) (tokens *OIDCTokenResponse, err error) {
- request, err := http.NewRequest("POST", client.buildUrl(tokenPath), strings.NewReader(body))
- if err != nil {
- return nil, err
- }
- request.Header.Add("Content-Type", "application/x-www-form-urlencoded")
- resp, err := client.httpClient.Do(request)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- err = client.checkResponse(resp)
- if err != nil {
- return nil, err
- }
- tokens = &OIDCTokenResponse{}
- err = json.NewDecoder(resp.Body).Decode(tokens)
- if err != nil {
- return nil, err
- }
- return
- }
// OIDCError is the error document decoded from a non-2xx token endpoint
// response. It implements the error interface.
type OIDCError struct {
	// Code is the value of the "error" JSON field.
	Code string `json:"error"`
	// Message is the value of the "error_description" JSON field.
	Message string `json:"error_description"`
}

// Error formats the error as "<Code>: <Message>".
func (e OIDCError) Error() string {
	return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
- func (client *OIDCClient) checkResponse(response *http.Response) (err error) {
- if response.StatusCode/100 == 2 {
- return
- }
- respBody, readErr := ioutil.ReadAll(response.Body)
- if err != nil {
- return fmt.Errorf(
- "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
- }
- var oidcErr OIDCError
- err = json.Unmarshal(respBody, &oidcErr)
- if err != nil {
- return fmt.Errorf(
- "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr)
- }
- return oidcErr
- }
|