123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- // Copyright (c) 2017 VMware, Inc. All Rights Reserved.
- //
- // This product is licensed to you under the Apache License, Version 2.0 (the "License").
- // You may not use this product except in compliance with the License.
- //
- // This product may include a number of subcomponents with separate copyright notices and
- // license terms. Your use of these subcomponents is subject to the terms and conditions
- // of the subcomponent's license, as noted in the LICENSE file.
- // +build windows
- package lightwave
- import (
- "encoding/base64"
- "fmt"
- "github.com/vmware/photon-controller-go-sdk/SSPI"
- "math/rand"
- "net"
- "net/url"
- "strings"
- "time"
- )
- const gssTicketGrantFormatString = "grant_type=urn:vmware:grant_type:gss_ticket&gss_ticket=%s&context_id=%s&scope=%s"
- // GetTokensFromWindowsLogInContext gets tokens based on Windows logged in context
- // Here is how it works:
- // 1. Get the SPN (Service Principal Name) in the format host/FQDN of lightwave. This is needed for SSPI/Kerberos protocol
- // 2. Call Windows API AcquireCredentialsHandle() using SSPI library. This will give the current users credential handle
- // 3. Using this handle call Windows API AcquireCredentialsHandle(). This will give you byte[]
- // 4. Encode this byte[] and send it to OIDC server over HTTP (using POST)
- // 5. OIDC server can send either of the following
- // - Access tokens. In this case return access tokens to client
- // - Error in the format: invalid_grant: gss_continue_needed:'context id':'token from server'
- // 6. In case you get error, parse it and get the token from server
- // 7. Feed this token to step 3 and repeat steps till you get the access tokens from server
- func (client *OIDCClient) GetTokensFromWindowsLogInContext() (tokens *OIDCTokenResponse, err error) {
- spn, err := client.buildSPN()
- if err != nil {
- return nil, err
- }
- auth, _ := SSPI.GetAuth("", "", spn, "")
- userContext, err := auth.InitialBytes()
- if err != nil {
- return nil, err
- }
- // In case of multiple req/res between client and server (as explained in above comment),
- // server needs to maintain the mapping of context id -> token
- // So we need to generate random string as a context id
- // If we use same context id for all the requests, results can be erroneous
- contextId := client.generateRandomString()
- body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
- tokens, err = client.getToken(body)
- for {
- if err == nil {
- break
- }
- // In case of error the response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
- gssToken := client.validateAndExtractGSSResponse(err, contextId)
- if gssToken == "" {
- return nil, err
- }
- data, err := base64.StdEncoding.DecodeString(gssToken)
- if err != nil {
- return nil, err
- }
- userContext, err := auth.NextBytes(data)
- body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
- tokens, err = client.getToken(body)
- }
- return tokens, err
- }
- // Gets the SPN (Service Principal Name) in the format host/FQDN of lightwave
- func (client *OIDCClient) buildSPN() (spn string, err error) {
- u, err := url.Parse(client.Endpoint)
- if err != nil {
- return "", err
- }
- host, _, err := net.SplitHostPort(u.Host)
- if err != nil {
- return "", err
- }
- addr, err := net.LookupAddr(host)
- if err != nil {
- return "", err
- }
- var s = strings.TrimSuffix(addr[0], ".")
- return "host/" + s, nil
- }
- // validateAndExtractGSSResponse parse the error from server and returns token from server
- // In case of error from the server, response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
- // So, we check for the above format in error and then return the token from server
- // If error is not in above format, we return empty string
- func (client *OIDCClient) validateAndExtractGSSResponse(err error, contextId string) string {
- parts := strings.Split(err.Error(), ":")
- if !(len(parts) == 4 && strings.TrimSpace(parts[1]) == "gss_continue_needed" && parts[2] == contextId) {
- return ""
- } else {
- return parts[3]
- }
- }
- func (client *OIDCClient) generateRandomString() string {
- const length = 10
- const asciiA = 65
- const asciiZ = 90
- rand.Seed(time.Now().UTC().UnixNano())
- bytes := make([]byte, length)
- for i := 0; i < length; i++ {
- bytes[i] = byte(randInt(asciiA, asciiZ))
- }
- return string(bytes)
- }
- func randInt(min int, max int) int {
- return min + rand.Intn(max-min)
- }
|