oidcclient_sspi.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. // Copyright (c) 2017 VMware, Inc. All Rights Reserved.
  2. //
  3. // This product is licensed to you under the Apache License, Version 2.0 (the "License").
  4. // You may not use this product except in compliance with the License.
  5. //
  6. // This product may include a number of subcomponents with separate copyright notices and
  7. // license terms. Your use of these subcomponents is subject to the terms and conditions
  8. // of the subcomponent's license, as noted in the LICENSE file.
  9. // +build windows
  10. package lightwave
  11. import (
  12. "encoding/base64"
  13. "fmt"
  14. "github.com/vmware/photon-controller-go-sdk/SSPI"
  15. "math/rand"
  16. "net"
  17. "net/url"
  18. "strings"
  19. "time"
  20. )
  21. const gssTicketGrantFormatString = "grant_type=urn:vmware:grant_type:gss_ticket&gss_ticket=%s&context_id=%s&scope=%s"
  22. // GetTokensFromWindowsLogInContext gets tokens based on Windows logged in context
  23. // Here is how it works:
  24. // 1. Get the SPN (Service Principal Name) in the format host/FQDN of lightwave. This is needed for SSPI/Kerberos protocol
  25. // 2. Call Windows API AcquireCredentialsHandle() using SSPI library. This will give the current users credential handle
  26. // 3. Using this handle call Windows API AcquireCredentialsHandle(). This will give you byte[]
  27. // 4. Encode this byte[] and send it to OIDC server over HTTP (using POST)
  28. // 5. OIDC server can send either of the following
  29. // - Access tokens. In this case return access tokens to client
  30. // - Error in the format: invalid_grant: gss_continue_needed:'context id':'token from server'
  31. // 6. In case you get error, parse it and get the token from server
  32. // 7. Feed this token to step 3 and repeat steps till you get the access tokens from server
  33. func (client *OIDCClient) GetTokensFromWindowsLogInContext() (tokens *OIDCTokenResponse, err error) {
  34. spn, err := client.buildSPN()
  35. if err != nil {
  36. return nil, err
  37. }
  38. auth, _ := SSPI.GetAuth("", "", spn, "")
  39. userContext, err := auth.InitialBytes()
  40. if err != nil {
  41. return nil, err
  42. }
  43. // In case of multiple req/res between client and server (as explained in above comment),
  44. // server needs to maintain the mapping of context id -> token
  45. // So we need to generate random string as a context id
  46. // If we use same context id for all the requests, results can be erroneous
  47. contextId := client.generateRandomString()
  48. body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
  49. tokens, err = client.getToken(body)
  50. for {
  51. if err == nil {
  52. break
  53. }
  54. // In case of error the response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
  55. gssToken := client.validateAndExtractGSSResponse(err, contextId)
  56. if gssToken == "" {
  57. return nil, err
  58. }
  59. data, err := base64.StdEncoding.DecodeString(gssToken)
  60. if err != nil {
  61. return nil, err
  62. }
  63. userContext, err := auth.NextBytes(data)
  64. body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
  65. tokens, err = client.getToken(body)
  66. }
  67. return tokens, err
  68. }
  69. // Gets the SPN (Service Principal Name) in the format host/FQDN of lightwave
  70. func (client *OIDCClient) buildSPN() (spn string, err error) {
  71. u, err := url.Parse(client.Endpoint)
  72. if err != nil {
  73. return "", err
  74. }
  75. host, _, err := net.SplitHostPort(u.Host)
  76. if err != nil {
  77. return "", err
  78. }
  79. addr, err := net.LookupAddr(host)
  80. if err != nil {
  81. return "", err
  82. }
  83. var s = strings.TrimSuffix(addr[0], ".")
  84. return "host/" + s, nil
  85. }
  86. // validateAndExtractGSSResponse parse the error from server and returns token from server
  87. // In case of error from the server, response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
  88. // So, we check for the above format in error and then return the token from server
  89. // If error is not in above format, we return empty string
  90. func (client *OIDCClient) validateAndExtractGSSResponse(err error, contextId string) string {
  91. parts := strings.Split(err.Error(), ":")
  92. if !(len(parts) == 4 && strings.TrimSpace(parts[1]) == "gss_continue_needed" && parts[2] == contextId) {
  93. return ""
  94. } else {
  95. return parts[3]
  96. }
  97. }
  98. func (client *OIDCClient) generateRandomString() string {
  99. const length = 10
  100. const asciiA = 65
  101. const asciiZ = 90
  102. rand.Seed(time.Now().UTC().UnixNano())
  103. bytes := make([]byte, length)
  104. for i := 0; i < length; i++ {
  105. bytes[i] = byte(randInt(asciiA, asciiZ))
  106. }
  107. return string(bytes)
  108. }
  109. func randInt(min int, max int) int {
  110. return min + rand.Intn(max-min)
  111. }