package wsconn import ( "context" "encoding/json" "io" "sync" "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "src.dualinventive.com/go/dinet/rpc" "src.dualinventive.com/go/websocketserver/internal/mtiwss" ) // client keeps track a single websocket connection and subscription fields type client struct { isClosed bool con *websocket.Conn wg sync.WaitGroup cancel func() mu sync.RWMutex rchan chan []byte // rchan is used to read messages from the websocket wchan chan []byte // wchan is used to write messages to the websocket deviceFields map[string][]string // device:uid [field1, field2] projectFields map[uint64][]string // project:id [field1, field2] } // filterFields filters values which are not present in the fields slice func filterFields(values map[string]string, fields []string) map[string]string { cvals := make(map[string]string) for k, v := range values { for _, f := range fields { if k == f { cvals[k] = v } } } return cvals } // newClient creates a new websocket client func newClient(c *websocket.Conn) *client { wsc := &client{con: c} wsc.rchan = make(chan []byte) wsc.wchan = make(chan []byte) return wsc } // Publish a raw message over the websocket func (c *client) Send(msg []byte) (err error) { c.mu.RLock() if c.isClosed { c.mu.RUnlock() return io.EOF } select { case c.wchan <- msg: default: } c.mu.RUnlock() return nil } // Close closes the websocket connection func (c *client) Close() error { c.mu.Lock() if c.isClosed { c.mu.Unlock() return io.EOF } c.isClosed = true close(c.wchan) c.wchan = nil c.cancel() err := c.con.Close() c.wg.Wait() c.con = nil c.mu.Unlock() return err } // PublishJSON marshals and send v to the client func (c *client) PublishJSON(v interface{}) error { msg, err := json.Marshal(v) if err != nil { return err } return c.Send(msg) } // PublishDeviceUIDUpdate publishes a filtered update for the DeviceUID based on the client // subscribed fields func (c *client) PublishDeviceUIDUpdate(DeviceUID string, values map[string]string) { c.mu.RLock() defer c.mu.RUnlock() // Get the subscribed fields for the DeviceUID fields, ok := c.deviceFields[DeviceUID] if !ok { return } // Filter the values based on the connection subscription fields fvalues := filterFields(values, fields) if len(fvalues) == 0 { return } // Prepare the DI-Net RPC realtime:data publish message for sending msg := &rpc.Msg{ DeviceUID: DeviceUID, Type: rpc.MsgTypePublish, ClassMethod: rpc.ClassMethodRealtimeData, Result: jsonMapString(fvalues), } if err := c.PublishJSON(msg); err != nil { logrus.Warnf("Publish JSON failed: %v", err) } } // PublishProjectIDUpdate publishes a filtered update for the ProjectID based on the client // subscribed fields func (c *client) PublishProjectIDUpdate(ProjectID uint64, values map[string]string) { c.mu.RLock() defer c.mu.RUnlock() // Get the subscribed fields for the ProjectID fields, ok := c.projectFields[ProjectID] if !ok { return } // Filter the values based on the connection subscription fields fvalues := filterFields(values, fields) if len(fvalues) == 0 { return } // Prepare the DI-Net RPC realtime:data publish message for sending msg := &rpc.Msg{ ProjectID: uint(ProjectID), Type: rpc.MsgTypePublish, ClassMethod: rpc.ClassMethodRealtimeData, Result: jsonMapString(fvalues), } if err := c.PublishJSON(msg); err != nil { logrus.Warnf("Publish JSON failed: %v", err) } } // serveRead reads from the websocket and writes to c.rchan func (c *client) serveRead() { defer c.wg.Done() for { _, msg, err := c.con.ReadMessage() if err != nil { close(c.rchan) logrus.Warnf("error reading message: %v", err) return } c.rchan <- msg } } // Serve serves the client messages and dispatches requests to mtiwss func (c *client) Serve(ctx context.Context, mtiwss *mtiwss.Mtiwss) { // Start read worker c.wg.Add(1) go c.serveRead() for { select { case <-ctx.Done(): return case msg := <-c.wchan: err := c.con.WriteMessage(websocket.TextMessage, msg) if err != nil { logrus.Warnf("error writing message: %v", err) return } case msg, ok := <-c.rchan: if !ok { // When rchan fails, it comes here to exit from serve c.rchan = nil return } if ok, err := c.mtiWssDecodeRequest(msg); ok || err != nil { logrus.Warnf("error decoding wss request: %v", err) continue } c.wg.Add(1) go func() { defer c.wg.Done() rep, err := mtiwss.Request(msg) if err != nil { logrus.Warnf("error requesting to mtiwss: %v", err) return } c.mtiWssDecodeReply(rep) }() } } }