// Copyright (c) 2016 VMware, Inc. All Rights Reserved. // // This product is licensed to you under the Apache License, Version 2.0 (the "License"). // You may not use this product except in compliance with the License. // // This product may include a number of subcomponents with separate copyright notices and // license terms. Your use of these subcomponents is subject to the terms and conditions // of the subcomponent's license, as noted in the LICENSE file. package lightwave import ( "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "io/ioutil" "log" "net/http" "net/url" "strings" ) const tokenScope string = "openid offline_access" type OIDCClient struct { httpClient *http.Client logger *log.Logger Endpoint string Options *OIDCClientOptions } type OIDCClientOptions struct { // Whether or not to ignore any TLS errors when talking to photon, // false by default. IgnoreCertificate bool // List of root CA's to use for server validation // nil by default. RootCAs *x509.CertPool // The scope values to use when requesting tokens TokenScope string } func NewOIDCClient(endpoint string, options *OIDCClientOptions, logger *log.Logger) (c *OIDCClient) { if logger == nil { logger = log.New(ioutil.Discard, "", log.LstdFlags) } options = buildOptions(options) tr := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: options.IgnoreCertificate, RootCAs: options.RootCAs}, } c = &OIDCClient{ httpClient: &http.Client{Transport: tr}, logger: logger, Endpoint: strings.TrimRight(endpoint, "/"), Options: options, } return } func buildOptions(options *OIDCClientOptions) (result *OIDCClientOptions) { result = &OIDCClientOptions{ TokenScope: tokenScope, } if options == nil { return } result.IgnoreCertificate = options.IgnoreCertificate if options.RootCAs != nil { result.RootCAs = options.RootCAs } if options.TokenScope != "" { result.TokenScope = options.TokenScope } return } func (client *OIDCClient) buildUrl(path string) (url string) { return fmt.Sprintf("%s%s", client.Endpoint, path) } // Cert download helper const certDownloadPath string = "/afd/vecs/ssl" type lightWaveCert struct { Value string `json:"encoded"` } func (client *OIDCClient) GetRootCerts() (certList []*x509.Certificate, err error) { // turn TLS verification off for originalTr := client.httpClient.Transport defer client.setTransport(originalTr) tr := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, } client.setTransport(tr) // get the certs resp, err := client.httpClient.Get(client.buildUrl(certDownloadPath)) if err != nil { return } defer resp.Body.Close() if resp.StatusCode != 200 { err = fmt.Errorf("Unexpected error retrieving auth server certs: %v %s", resp.StatusCode, resp.Status) return } // parse the certs certsData := &[]lightWaveCert{} err = json.NewDecoder(resp.Body).Decode(certsData) if err != nil { return } certList = make([]*x509.Certificate, len(*certsData)) for idx, cert := range *certsData { block, _ := pem.Decode([]byte(cert.Value)) if block == nil { err = fmt.Errorf("Unexpected response format: %v", certsData) return nil, err } decodedCert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, err } certList[idx] = decodedCert } return } func (client *OIDCClient) setTransport(tr http.RoundTripper) { client.httpClient.Transport = tr } // Toke request helpers const tokenPath string = "/openidconnect/token" const passwordGrantFormatString = "grant_type=password&username=%s&password=%s&scope=%s" const refreshTokenGrantFormatString = "grant_type=refresh_token&refresh_token=%s" type OIDCTokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token,omitempty"` IdToken string `json:"id_token"` TokenType string `json:"token_type"` } func (client *OIDCClient) GetTokenByPasswordGrant(username string, password string) (tokens *OIDCTokenResponse, err error) { username = url.QueryEscape(username) password = url.QueryEscape(password) body := fmt.Sprintf(passwordGrantFormatString, username, password, client.Options.TokenScope) return client.getToken(body) } func (client *OIDCClient) GetTokenByRefreshTokenGrant(refreshToken string) (tokens *OIDCTokenResponse, err error) { body := fmt.Sprintf(refreshTokenGrantFormatString, refreshToken) return client.getToken(body) } func (client *OIDCClient) getToken(body string) (tokens *OIDCTokenResponse, err error) { request, err := http.NewRequest("POST", client.buildUrl(tokenPath), strings.NewReader(body)) if err != nil { return nil, err } request.Header.Add("Content-Type", "application/x-www-form-urlencoded") resp, err := client.httpClient.Do(request) if err != nil { return nil, err } defer resp.Body.Close() err = client.checkResponse(resp) if err != nil { return nil, err } tokens = &OIDCTokenResponse{} err = json.NewDecoder(resp.Body).Decode(tokens) if err != nil { return nil, err } return } type OIDCError struct { Code string `json:"error"` Message string `json:"error_description"` } func (e OIDCError) Error() string { return fmt.Sprintf("%v: %v", e.Code, e.Message) } func (client *OIDCClient) checkResponse(response *http.Response) (err error) { if response.StatusCode/100 == 2 { return } respBody, readErr := ioutil.ReadAll(response.Body) if err != nil { return fmt.Errorf( "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr) } var oidcErr OIDCError err = json.Unmarshal(respBody, &oidcErr) if err != nil { return fmt.Errorf( "Status: %v, Body: %v [%v]", response.Status, string(respBody[:]), readErr) } return oidcErr }