package storageos

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/storageos/go-api/netutil"
	"github.com/storageos/go-api/serror"
)

const (
	// DefaultUserAgent is the default User-Agent header to include in HTTP requests.
	DefaultUserAgent = "go-storageosclient"
	// DefaultVersionStr is the string value of the default API version.
	DefaultVersionStr = "1"
	// DefaultVersion is the default API version.
	DefaultVersion = 1
)

var (
	// ErrConnectionRefused is returned when the client cannot connect to the given endpoint.
	ErrConnectionRefused = errors.New("cannot connect to StorageOS API endpoint")

	// ErrInactivityTimeout is returned when a streamable call has been inactive for some time.
	ErrInactivityTimeout = errors.New("inactivity time exceeded timeout")

	// ErrInvalidVersion is returned when a versioned client was requested but no version specified.
	ErrInvalidVersion = errors.New("invalid version")

	// ErrProxyNotSupported is returned when a client is unable to set a proxy for http requests.
	ErrProxyNotSupported = errors.New("client does not support http proxy")

	// DefaultPort is the default API port.
	DefaultPort = "5705"

	// DataplaneHealthPort is the the port used by the dataplane health-check service.
	DataplaneHealthPort = "5704"

	// DefaultHost is the default API host.
	DefaultHost = "http://localhost:" + DefaultPort
)

// APIVersion is an internal representation of a version of the Remote API.
type APIVersion int

// NewAPIVersion returns an instance of APIVersion for the given string.
//
// The given string must be in the form <major>
func NewAPIVersion(input string) (APIVersion, error) {
	if input == "" {
		return DefaultVersion, ErrInvalidVersion
	}
	version, err := strconv.Atoi(input)
	if err != nil {
		return 0, fmt.Errorf("Unable to parse version %q", input)
	}
	return APIVersion(version), nil
}

func (version APIVersion) String() string {
	return fmt.Sprintf("v%d", version)
}

// Client is the basic type of this package. It provides methods for
// interaction with the API.
type Client struct {
	httpClient *http.Client

	addresses []string
	username  string
	secret    string
	userAgent string

	configLock  *sync.RWMutex // Lock for config changes
	addressLock *sync.Mutex   // Lock used to copy/update the address slice

	requestedAPIVersion APIVersion
	serverAPIVersion    APIVersion
	expectedAPIVersion  APIVersion

	SkipServerVersionCheck bool
}

// ClientVersion returns the API version of the client
func (c *Client) ClientVersion() string {
	return DefaultVersionStr
}

// Dialer is an interface that allows network connections to be dialed
// (net.Dialer fulfills this interface) and named pipes (a shim using
// winio.DialPipe)
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
}

// NewClient returns a Client instance ready for communication with the given
// server endpoint. It will use the latest remote API version available in the
// server.
func NewClient(nodes string) (*Client, error) {
	client, err := NewVersionedClient(nodes, "")
	if err != nil {
		return nil, err
	}
	client.SkipServerVersionCheck = true
	client.userAgent = DefaultUserAgent
	return client, nil
}

// NewVersionedClient returns a Client instance ready for communication with
// the given server endpoint, using a specific remote API version.
func NewVersionedClient(nodestring string, apiVersionString string) (*Client, error) {
	nodes := strings.Split(nodestring, ",")
	addresses, err := netutil.AddressesFromNodes(nodes)
	if err != nil {
		return nil, err
	}

	if len(addresses) > 1 {
		// Shuffle returned addresses in attempt to spread the load
		rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
		rnd.Shuffle(len(addresses), func(i, j int) {
			addresses[i], addresses[j] = addresses[j], addresses[i]
		})
	}

	client := &Client{
		httpClient:  defaultClient(),
		addresses:   addresses,
		configLock:  &sync.RWMutex{},
		addressLock: &sync.Mutex{},
	}

	if apiVersionString != "" {
		version, err := strconv.Atoi(apiVersionString)
		if err != nil {
			return nil, err
		}
		client.requestedAPIVersion = APIVersion(version)
	}

	return client, nil
}

// SetUserAgent sets the client useragent.
func (c *Client) SetUserAgent(useragent string) {
	c.configLock.Lock()
	defer c.configLock.Unlock()
	c.userAgent = useragent
}

// SetAuth sets the API username and secret to be used for all API requests.
// It should not be called concurrently with any other Client methods.
func (c *Client) SetAuth(username string, secret string) {
	c.configLock.Lock()
	defer c.configLock.Unlock()
	if username != "" {
		c.username = username
	}
	if secret != "" {
		c.secret = secret
	}
}

// SetProxy will set the proxy URL for both the HTTPClient.
// If the transport method does not support usage
// of proxies, an error will be returned.
func (c *Client) SetProxy(proxy *url.URL) error {
	c.configLock.Lock()
	defer c.configLock.Unlock()

	if client := c.httpClient; client != nil {
		transport, supported := client.Transport.(*http.Transport)
		if !supported {
			return ErrProxyNotSupported
		}
		transport.Proxy = http.ProxyURL(proxy)
	}
	return nil
}

// SetTimeout takes a timeout and applies it to both the HTTPClient and
// nativeHTTPClient. It should not be called concurrently with any other Client
// methods.
func (c *Client) SetTimeout(t time.Duration) {
	c.configLock.Lock()
	defer c.configLock.Unlock()
	if c.httpClient != nil {
		c.httpClient.Timeout = t
	}
}

func (c *Client) checkAPIVersion() error {
	serverAPIVersionString, err := c.getServerAPIVersionString()
	if err != nil {
		return err
	}
	c.serverAPIVersion, err = NewAPIVersion(serverAPIVersionString)
	if err != nil {
		return err
	}
	c.configLock.Lock()
	defer c.configLock.Unlock()
	if c.requestedAPIVersion == 0 {
		c.expectedAPIVersion = c.serverAPIVersion
	} else {
		c.expectedAPIVersion = c.requestedAPIVersion
	}
	return nil
}

// Ping pings the API server
//
// See https://goo.gl/wYfgY1 for more details.
func (c *Client) Ping() error {
	urlpath := "/_ping"
	resp, err := c.do("GET", urlpath, doOptions{})
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		return newError(resp)
	}
	return resp.Body.Close()
}

func (c *Client) getServerAPIVersionString() (version string, err error) {
	v, err := c.ServerVersion(context.Background())
	if err != nil {
		return "", err
	}
	return v.APIVersion, nil
}

type doOptions struct {
	context context.Context
	data    interface{}

	values  url.Values
	headers map[string]string

	fieldSelector string
	labelSelector string
	namespace     string

	forceJSON   bool
	force       bool
	unversioned bool
}

func (c *Client) do(method, urlpath string, doOptions doOptions) (*http.Response, error) {
	var params io.Reader
	if doOptions.data != nil || doOptions.forceJSON {
		buf, err := json.Marshal(doOptions.data)
		if err != nil {
			return nil, err
		}
		params = bytes.NewBuffer(buf)
	}

	// Prefix the path with the namespace if given.  The caller should only set
	// the namespace if this is desired.
	if doOptions.namespace != "" {
		urlpath = "/" + NamespaceAPIPrefix + "/" + doOptions.namespace + "/" + urlpath
	}

	if !c.SkipServerVersionCheck && !doOptions.unversioned {
		err := c.checkAPIVersion()
		if err != nil {
			return nil, err
		}
	}

	query := url.Values{}
	if doOptions.values != nil {
		query = doOptions.values
	}
	if doOptions.force {
		query.Add("force", "1")
	}

	// Obtain a reader lock to prevent the http client from being
	// modified underneath us during a do().
	c.configLock.RLock()
	defer c.configLock.RUnlock() // This defer matches both the initial and the above lock

	httpClient := c.httpClient
	endpoint := c.getAPIPath(urlpath, query, doOptions.unversioned)

	// The doOptions Context is shared for every attempted request in the do.
	ctx := doOptions.context
	if ctx == nil {
		ctx = context.Background()
	}

	var failedAddresses = map[string]struct{}{}

	c.addressLock.Lock()
	var addresses = make([]string, len(c.addresses))
	copy(addresses, c.addresses)
	c.addressLock.Unlock()

	for _, address := range addresses {
		target := address + endpoint

		req, err := http.NewRequest(method, target, params)
		if err != nil {
			// Probably should not try and continue if we're unable
			// to create the request.
			return nil, err
		}
		req.Header.Set("User-Agent", c.userAgent)
		if doOptions.data != nil {
			req.Header.Set("Content-Type", "application/json")
		} else if method == "POST" {
			req.Header.Set("Content-Type", "plain/text")
		}
		if c.username != "" && c.secret != "" {
			req.SetBasicAuth(c.username, c.secret)
		}

		for k, v := range doOptions.headers {
			req.Header.Set(k, v)
		}

		resp, err := httpClient.Do(req.WithContext(ctx))
		if err != nil {
			// If it is a custom error, return it. It probably knows more than us
			if serror.IsStorageOSError(err) {
				switch serror.ErrorKind(err) {
				case serror.APIUncontactable:
					// If API isn't contactable we should try the next address
					failedAddresses[address] = struct{}{}
					continue
				case serror.InvalidHostConfig:
					// If invalid host or unknown error, we should report back
					fallthrough
				case serror.UnknownError:
					return nil, err
				}
			}

			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			default:
				if _, ok := err.(net.Error); ok {
					// Be optimistic and try the next endpoint
					failedAddresses[address] = struct{}{}
					continue
				}
				return nil, err
			}
		}

		// If we get to the point of response, we should move any failed
		// addresses to the back.
		failed := len(failedAddresses)
		if failed > 0 {
			// Copy addresses we think are okay into the head of the list
			newOrder := make([]string, 0, len(addresses)-failed)

			for _, addr := range addresses {
				if _, exists := failedAddresses[addr]; !exists {
					newOrder = append(newOrder, addr)
				}
			}
			for addr := range failedAddresses {
				newOrder = append(newOrder, addr)
			}

			c.addressLock.Lock()
			// Bring in the new order
			c.addresses = newOrder
			c.addressLock.Unlock()
		}

		if resp.StatusCode < 200 || resp.StatusCode >= 400 {
			return nil, newError(resp) // These status codes are likely to be fatal
		}
		return resp, nil
	}

	return nil, netutil.ErrAllFailed(addresses)
}

func (c *Client) getAPIPath(path string, query url.Values, unversioned bool) string {
	var apiPath = strings.TrimLeft(path, "/")

	if !unversioned {
		apiPath = fmt.Sprintf("/%s/%s", c.requestedAPIVersion, apiPath)
	} else {
		apiPath = fmt.Sprintf("/%s", apiPath)
	}

	if len(query) > 0 {
		apiPath = apiPath + "?" + query.Encode()
	}

	return apiPath
}

func queryString(opts interface{}) string {
	if opts == nil {
		return ""
	}
	value := reflect.ValueOf(opts)
	if value.Kind() == reflect.Ptr {
		value = value.Elem()
	}
	if value.Kind() != reflect.Struct {
		return ""
	}
	items := url.Values(map[string][]string{})
	for i := 0; i < value.NumField(); i++ {
		field := value.Type().Field(i)
		if field.PkgPath != "" {
			continue
		}
		key := field.Tag.Get("qs")
		if key == "" {
			key = strings.ToLower(field.Name)
		} else if key == "-" {
			continue
		}
		addQueryStringValue(items, key, value.Field(i))
	}
	return items.Encode()
}

func addQueryStringValue(items url.Values, key string, v reflect.Value) {
	switch v.Kind() {
	case reflect.Bool:
		if v.Bool() {
			items.Add(key, "1")
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if v.Int() > 0 {
			items.Add(key, strconv.FormatInt(v.Int(), 10))
		}
	case reflect.Float32, reflect.Float64:
		if v.Float() > 0 {
			items.Add(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
		}
	case reflect.String:
		if v.String() != "" {
			items.Add(key, v.String())
		}
	case reflect.Ptr:
		if !v.IsNil() {
			if b, err := json.Marshal(v.Interface()); err == nil {
				items.Add(key, string(b))
			}
		}
	case reflect.Map:
		if len(v.MapKeys()) > 0 {
			if b, err := json.Marshal(v.Interface()); err == nil {
				items.Add(key, string(b))
			}
		}
	case reflect.Array, reflect.Slice:
		vLen := v.Len()
		if vLen > 0 {
			for i := 0; i < vLen; i++ {
				addQueryStringValue(items, key, v.Index(i))
			}
		}
	}
}

// Error represents failures in the API. It represents a failure from the API.
type Error struct {
	Status  int
	Message string
}

func newError(resp *http.Response) *Error {
	type jsonError struct {
		Message string `json:"message"`
	}

	defer resp.Body.Close()
	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return &Error{Status: resp.StatusCode, Message: fmt.Sprintf("cannot read body, err: %v", err)}
	}

	// attempt to unmarshal the error if in json format
	jerr := &jsonError{}
	err = json.Unmarshal(data, jerr)
	if err != nil {
		return &Error{Status: resp.StatusCode, Message: string(data)} // Failed, just return string
	}

	return &Error{Status: resp.StatusCode, Message: jerr.Message}
}

func (e *Error) Error() string {
	var niceStatus string

	switch e.Status {
	case 400, 500:
		niceStatus = "Server failed to process your request. Was the data correct?"
	case 401:
		niceStatus = "Unauthenticated access of secure endpoint, please retry after authentication"
	case 403:
		niceStatus = "Forbidden request. Your user cannot perform this action"
	case 404:
		niceStatus = "Requested object not found. Does this item exist?"
	}

	if niceStatus != "" {
		return fmt.Sprintf("API error (%s): %s", niceStatus, e.Message)
	}
	return fmt.Sprintf("API error (%s): %s", http.StatusText(e.Status), e.Message)
}

// defaultPooledTransport returns a new http.Transport with similar default
// values to http.DefaultTransport. Do not use this for transient transports as
// it can leak file descriptors over time. Only use this for transports that
// will be re-used for the same host(s).
func defaultPooledTransport(dialer Dialer) *http.Transport {
	transport := &http.Transport{
		Proxy:               http.ProxyFromEnvironment,
		Dial:                dialer.Dial,
		TLSHandshakeTimeout: 5 * time.Second,
		DisableKeepAlives:   false,
		MaxIdleConnsPerHost: 1,
	}
	return transport
}

// defaultClient returns a new http.Client with similar default values to
// http.Client, but with a non-shared Transport, idle connections disabled, and
// keepalives disabled.
// If a custom dialer is not provided, one with sane defaults will be created.
func defaultClient() *http.Client {
	dialer := &net.Dialer{
		Timeout:   5 * time.Second,
		KeepAlive: 5 * time.Second,
	}

	return &http.Client{
		Transport: defaultPooledTransport(dialer),
	}
}