device_plugin_stub.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /*
  2. Copyright 2017 The Kubernetes Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package devicemanager
import (
	"context"
	"log"
	"net"
	"os"
	"path"
	"path/filepath"
	"sync"
	"time"

	"google.golang.org/grpc"
	pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
	watcherapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1"
)
  26. // Stub implementation for DevicePlugin.
  27. type Stub struct {
  28. devs []*pluginapi.Device
  29. socket string
  30. resourceName string
  31. preStartContainerFlag bool
  32. stop chan interface{}
  33. wg sync.WaitGroup
  34. update chan []*pluginapi.Device
  35. server *grpc.Server
  36. // allocFunc is used for handling allocation request
  37. allocFunc stubAllocFunc
  38. registrationStatus chan watcherapi.RegistrationStatus // for testing
  39. endpoint string // for testing
  40. }
  41. // stubAllocFunc is the function called when receive an allocation request from Kubelet
  42. type stubAllocFunc func(r *pluginapi.AllocateRequest, devs map[string]pluginapi.Device) (*pluginapi.AllocateResponse, error)
  43. func defaultAllocFunc(r *pluginapi.AllocateRequest, devs map[string]pluginapi.Device) (*pluginapi.AllocateResponse, error) {
  44. var response pluginapi.AllocateResponse
  45. return &response, nil
  46. }
  47. // NewDevicePluginStub returns an initialized DevicePlugin Stub.
  48. func NewDevicePluginStub(devs []*pluginapi.Device, socket string, name string, preStartContainerFlag bool) *Stub {
  49. return &Stub{
  50. devs: devs,
  51. socket: socket,
  52. resourceName: name,
  53. preStartContainerFlag: preStartContainerFlag,
  54. stop: make(chan interface{}),
  55. update: make(chan []*pluginapi.Device),
  56. allocFunc: defaultAllocFunc,
  57. }
  58. }
  59. // SetAllocFunc sets allocFunc of the device plugin
  60. func (m *Stub) SetAllocFunc(f stubAllocFunc) {
  61. m.allocFunc = f
  62. }
  63. // Start starts the gRPC server of the device plugin. Can only
  64. // be called once.
  65. func (m *Stub) Start() error {
  66. err := m.cleanup()
  67. if err != nil {
  68. return err
  69. }
  70. sock, err := net.Listen("unix", m.socket)
  71. if err != nil {
  72. return err
  73. }
  74. m.wg.Add(1)
  75. m.server = grpc.NewServer([]grpc.ServerOption{}...)
  76. pluginapi.RegisterDevicePluginServer(m.server, m)
  77. watcherapi.RegisterRegistrationServer(m.server, m)
  78. go func() {
  79. defer m.wg.Done()
  80. m.server.Serve(sock)
  81. }()
  82. _, conn, err := dial(m.socket)
  83. if err != nil {
  84. return err
  85. }
  86. conn.Close()
  87. log.Println("Starting to serve on", m.socket)
  88. return nil
  89. }
  90. // Stop stops the gRPC server. Can be called without a prior Start
  91. // and more than once. Not safe to be called concurrently by different
  92. // goroutines!
  93. func (m *Stub) Stop() error {
  94. if m.server == nil {
  95. return nil
  96. }
  97. m.server.Stop()
  98. m.wg.Wait()
  99. m.server = nil
  100. close(m.stop) // This prevents re-starting the server.
  101. return m.cleanup()
  102. }
  103. // GetInfo is the RPC which return pluginInfo
  104. func (m *Stub) GetInfo(ctx context.Context, req *watcherapi.InfoRequest) (*watcherapi.PluginInfo, error) {
  105. log.Println("GetInfo")
  106. return &watcherapi.PluginInfo{
  107. Type: watcherapi.DevicePlugin,
  108. Name: m.resourceName,
  109. Endpoint: m.endpoint,
  110. SupportedVersions: []string{pluginapi.Version}}, nil
  111. }
  112. // NotifyRegistrationStatus receives the registration notification from watcher
  113. func (m *Stub) NotifyRegistrationStatus(ctx context.Context, status *watcherapi.RegistrationStatus) (*watcherapi.RegistrationStatusResponse, error) {
  114. if m.registrationStatus != nil {
  115. m.registrationStatus <- *status
  116. }
  117. if !status.PluginRegistered {
  118. log.Println("Registration failed: ", status.Error)
  119. }
  120. return &watcherapi.RegistrationStatusResponse{}, nil
  121. }
  122. // Register registers the device plugin for the given resourceName with Kubelet.
  123. func (m *Stub) Register(kubeletEndpoint, resourceName string, pluginSockDir string) error {
  124. if pluginSockDir != "" {
  125. if _, err := os.Stat(pluginSockDir + "DEPRECATION"); err == nil {
  126. log.Println("Deprecation file found. Skip registration.")
  127. return nil
  128. }
  129. }
  130. log.Println("Deprecation file not found. Invoke registration")
  131. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  132. defer cancel()
  133. conn, err := grpc.DialContext(ctx, kubeletEndpoint, grpc.WithInsecure(), grpc.WithBlock(),
  134. grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
  135. return net.DialTimeout("unix", addr, timeout)
  136. }))
  137. if err != nil {
  138. return err
  139. }
  140. defer conn.Close()
  141. client := pluginapi.NewRegistrationClient(conn)
  142. reqt := &pluginapi.RegisterRequest{
  143. Version: pluginapi.Version,
  144. Endpoint: path.Base(m.socket),
  145. ResourceName: resourceName,
  146. Options: &pluginapi.DevicePluginOptions{PreStartRequired: m.preStartContainerFlag},
  147. }
  148. _, err = client.Register(context.Background(), reqt)
  149. if err != nil {
  150. return err
  151. }
  152. return nil
  153. }
  154. // GetDevicePluginOptions returns DevicePluginOptions settings for the device plugin.
  155. func (m *Stub) GetDevicePluginOptions(ctx context.Context, e *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
  156. return &pluginapi.DevicePluginOptions{PreStartRequired: m.preStartContainerFlag}, nil
  157. }
  158. // PreStartContainer resets the devices received
  159. func (m *Stub) PreStartContainer(ctx context.Context, r *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
  160. log.Printf("PreStartContainer, %+v", r)
  161. return &pluginapi.PreStartContainerResponse{}, nil
  162. }
  163. // ListAndWatch lists devices and update that list according to the Update call
  164. func (m *Stub) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
  165. log.Println("ListAndWatch")
  166. s.Send(&pluginapi.ListAndWatchResponse{Devices: m.devs})
  167. for {
  168. select {
  169. case <-m.stop:
  170. return nil
  171. case updated := <-m.update:
  172. s.Send(&pluginapi.ListAndWatchResponse{Devices: updated})
  173. }
  174. }
  175. }
  176. // Update allows the device plugin to send new devices through ListAndWatch
  177. func (m *Stub) Update(devs []*pluginapi.Device) {
  178. m.update <- devs
  179. }
  180. // Allocate does a mock allocation
  181. func (m *Stub) Allocate(ctx context.Context, r *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
  182. log.Printf("Allocate, %+v", r)
  183. devs := make(map[string]pluginapi.Device)
  184. for _, dev := range m.devs {
  185. devs[dev.ID] = *dev
  186. }
  187. return m.allocFunc(r, devs)
  188. }
  189. func (m *Stub) cleanup() error {
  190. if err := os.Remove(m.socket); err != nil && !os.IsNotExist(err) {
  191. return err
  192. }
  193. return nil
  194. }