package dinet import ( "sync" "src.dualinventive.com/go/dinet/rpc" "src.dualinventive.com/go/lib/dilog" ) // RPCCallback is the function type to execute when the router fires the callback type RPCCallback func(*rpc.Msg) *rpc.Msg // Router routes messages to callbacks, consumers can subscribe using a connection type Router struct { logger dilog.Logger con ReadWriter routerAny rpcRouter routerPub rpcRouter routerRep rpcRouter routerReq rpcRouter } // NewRouter returns a new router based on the passed connection func NewRouter(logger dilog.Logger, con ReadWriter) (*Router, error) { r := &Router{logger: logger, con: con} return r, nil } // Close closes the parent connection func (c *Router) Close() error { if con, ok := c.con.(Closer); ok { return con.Close() } return nil } // ListenAndServe listens to incomming rpc messages and fires the callbacks when it satisfy the callback func (c *Router) ListenAndServe() error { for { msg, err := c.con.Recv() if err != nil { if msgError, ok := err.(*rpc.InvalidMsgError); ok { c.logger.WithError(msgError).WithField("msg", msgError.Msg).Warning("Invalid RPC message") if msgError.Msg.Type == rpc.MsgTypeRequest { rep := msgError.Msg.CreateReply() rep.SetError(rpc.EProto) err = c.con.Send(rep) if err != nil { c.logger.WithError(err).Error("Router.ListenAndServe send failed") return err } } } else { return err } } else if msg != nil { // msg can be nil when no errors are there but an non-rpc-msg is received c.handlePlainMessage(msg) } } } func (c *Router) handlePlainMessage(rpcmsg *rpc.Msg) { var replies []*rpc.Msg switch rpcmsg.Type { case rpc.MsgTypePublish: replies = c.routerPub.fire(rpcmsg) replies = append(replies, c.routerAny.fire(rpcmsg)...) case rpc.MsgTypeRequest: replies = c.routerReq.fire(rpcmsg) replies = append(replies, c.routerAny.fire(rpcmsg)...) if len(replies) == 0 { c.logger.WithField("classmethod", rpcmsg.ClassMethod).Debug("No reply for request") replies = append(replies, rpcmsg.CreateReply().SetError(rpc.EOpnotsupp)) } case rpc.MsgTypeReply: replies = c.routerRep.fire(rpcmsg) replies = append(replies, c.routerAny.fire(rpcmsg)...) } for _, reply := range replies { reply.DeviceUID = rpcmsg.DeviceUID reply.UserID = rpcmsg.UserID reply.ProjectID = rpcmsg.ProjectID if err := c.con.Send(reply); err != nil { c.logger.WithError(err).Debug("error while sending reply") } } } // Subscribe subscribes for incoming messages for a Device UID, function will panic if the message type is invalid func (c *Router) Subscribe(uid string, classMethod rpc.ClassMethod, t rpc.MsgType, cb RPCCallback) { switch t { case rpc.MsgTypePublish: c.routerPub.attach(uid, classMethod, cb) return case rpc.MsgTypeRequest: c.routerReq.attach(uid, classMethod, cb) return case rpc.MsgTypeReply: c.routerRep.attach(uid, classMethod, cb) return case rpc.MsgTypeAny: c.routerAny.attach(uid, classMethod, cb) return } panic("unknown message type") } // UnsubscribeDevice unsubscribes all callbacks for a specific device:uid func (c *Router) UnsubscribeDevice(uid string) { c.routerReq.detachDevice(uid) c.routerRep.detachDevice(uid) c.routerPub.detachDevice(uid) c.routerAny.detachDevice(uid) } // Send transfers the message to the connection func (c *Router) Send(msg *rpc.Msg) error { return c.con.Send(msg) } type rpcRouter struct { lock sync.RWMutex callbacks map[string]map[rpc.ClassMethod][]RPCCallback } func (r *rpcRouter) detachDevice(uid string) { r.lock.Lock() defer r.lock.Unlock() if r.callbacks[uid] != nil { r.callbacks[uid] = make(map[rpc.ClassMethod][]RPCCallback) } } func (r *rpcRouter) attach(uid string, classMethod rpc.ClassMethod, cb RPCCallback) { r.lock.Lock() defer r.lock.Unlock() if uid == "*" { uid = "" } if classMethod == "*" { classMethod = "" } if r.callbacks == nil { r.callbacks = make(map[string]map[rpc.ClassMethod][]RPCCallback) } if r.callbacks[uid] == nil { r.callbacks[uid] = make(map[rpc.ClassMethod][]RPCCallback) } r.callbacks[uid][classMethod] = append(r.callbacks[uid][classMethod], cb) } func (r *rpcRouter) fireCallbacks(msg *rpc.Msg, classMethodCallbacks map[rpc.ClassMethod][]RPCCallback) []*rpc.Msg { var msgs []*rpc.Msg // wildcard if wildCbList, ok := classMethodCallbacks[""]; ok { for _, cb := range wildCbList { m := cb(msg) if m != nil { msgs = append(msgs, m) } } } // specific if cbList, ok := classMethodCallbacks[msg.ClassMethod]; ok { for _, cb := range cbList { m := cb(msg) if m != nil { msgs = append(msgs, m) } } } return msgs } func (r *rpcRouter) fire(msg *rpc.Msg) []*rpc.Msg { var msgs []*rpc.Msg r.lock.RLock() defer r.lock.RUnlock() if wlcDevs, ok := r.callbacks[""]; ok { msgs = append(msgs, r.fireCallbacks(msg, wlcDevs)...) } if len(msg.DeviceUID) == 0 { return msgs } if directDevs, ok := r.callbacks[msg.DeviceUID]; ok { msgs = append(msgs, r.fireCallbacks(msg, directDevs)...) } return msgs }