204 lines
5.0 KiB
Go
204 lines
5.0 KiB
Go
package dinet
|
|
|
|
import (
|
|
"sync"
|
|
|
|
"src.dualinventive.com/go/dinet/rpc"
|
|
"src.dualinventive.com/go/lib/dilog"
|
|
)
|
|
|
|
// RPCCallback is the function type the router invokes for a message that
// matches a subscription. A non-nil return value is sent back over the
// router's connection as a reply (the router first copies DeviceUID, UserID
// and ProjectID from the incoming message onto it); returning nil sends
// nothing.
type RPCCallback func(*rpc.Msg) *rpc.Msg
|
|
|
|
// Router routes messages to callbacks, consumers can subscribe using a connection
|
|
type Router struct {
|
|
logger dilog.Logger
|
|
con ReadWriter
|
|
routerAny rpcRouter
|
|
routerPub rpcRouter
|
|
routerRep rpcRouter
|
|
routerReq rpcRouter
|
|
}
|
|
|
|
// NewRouter returns a new router based on the passed connection
|
|
func NewRouter(logger dilog.Logger, con ReadWriter) (*Router, error) {
|
|
r := &Router{logger: logger, con: con}
|
|
return r, nil
|
|
}
|
|
|
|
// Close closes the parent connection
|
|
func (c *Router) Close() error {
|
|
if con, ok := c.con.(Closer); ok {
|
|
return con.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListenAndServe listens to incomming rpc messages and fires the callbacks when it satisfy the callback
|
|
func (c *Router) ListenAndServe() error {
|
|
for {
|
|
msg, err := c.con.Recv()
|
|
if err != nil {
|
|
if msgError, ok := err.(*rpc.InvalidMsgError); ok {
|
|
c.logger.WithError(msgError).WithField("msg", msgError.Msg).Warning("Invalid RPC message")
|
|
if msgError.Msg.Type == rpc.MsgTypeRequest {
|
|
rep := msgError.Msg.CreateReply()
|
|
rep.SetError(rpc.EProto)
|
|
err = c.con.Send(rep)
|
|
if err != nil {
|
|
c.logger.WithError(err).Error("Router.ListenAndServe send failed")
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
} else if msg != nil {
|
|
// msg can be nil when no errors are there but an non-rpc-msg is received
|
|
c.handlePlainMessage(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Router) handlePlainMessage(rpcmsg *rpc.Msg) {
|
|
var replies []*rpc.Msg
|
|
switch rpcmsg.Type {
|
|
case rpc.MsgTypePublish:
|
|
replies = c.routerPub.fire(rpcmsg)
|
|
replies = append(replies, c.routerAny.fire(rpcmsg)...)
|
|
case rpc.MsgTypeRequest:
|
|
replies = c.routerReq.fire(rpcmsg)
|
|
replies = append(replies, c.routerAny.fire(rpcmsg)...)
|
|
if len(replies) == 0 {
|
|
c.logger.WithField("classmethod", rpcmsg.ClassMethod).Debug("No reply for request")
|
|
replies = append(replies, rpcmsg.CreateReply().SetError(rpc.EOpnotsupp))
|
|
}
|
|
case rpc.MsgTypeReply:
|
|
replies = c.routerRep.fire(rpcmsg)
|
|
replies = append(replies, c.routerAny.fire(rpcmsg)...)
|
|
}
|
|
for _, reply := range replies {
|
|
reply.DeviceUID = rpcmsg.DeviceUID
|
|
reply.UserID = rpcmsg.UserID
|
|
reply.ProjectID = rpcmsg.ProjectID
|
|
if err := c.con.Send(reply); err != nil {
|
|
c.logger.WithError(err).Debug("error while sending reply")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Subscribe subscribes for incoming messages for a Device UID, function will panic if the message type is invalid
|
|
func (c *Router) Subscribe(uid string, classMethod rpc.ClassMethod, t rpc.MsgType, cb RPCCallback) {
|
|
switch t {
|
|
case rpc.MsgTypePublish:
|
|
c.routerPub.attach(uid, classMethod, cb)
|
|
return
|
|
case rpc.MsgTypeRequest:
|
|
c.routerReq.attach(uid, classMethod, cb)
|
|
return
|
|
case rpc.MsgTypeReply:
|
|
c.routerRep.attach(uid, classMethod, cb)
|
|
return
|
|
case rpc.MsgTypeAny:
|
|
c.routerAny.attach(uid, classMethod, cb)
|
|
return
|
|
}
|
|
panic("unknown message type")
|
|
}
|
|
|
|
// UnsubscribeDevice unsubscribes all callbacks for a specific device:uid
|
|
func (c *Router) UnsubscribeDevice(uid string) {
|
|
c.routerReq.detachDevice(uid)
|
|
c.routerRep.detachDevice(uid)
|
|
c.routerPub.detachDevice(uid)
|
|
c.routerAny.detachDevice(uid)
|
|
}
|
|
|
|
// Send transfers the message to the connection
|
|
func (c *Router) Send(msg *rpc.Msg) error {
|
|
return c.con.Send(msg)
|
|
}
|
|
|
|
// rpcRouter is a callback table for one subscription type.
type rpcRouter struct {
	// lock guards callbacks.
	lock sync.RWMutex
	// callbacks is keyed by device UID, then by class method. The empty
	// string in either position is the wildcard (attach stores "*" as "").
	// Callbacks under one key run in registration order.
	callbacks map[string]map[rpc.ClassMethod][]RPCCallback
}
|
|
|
|
func (r *rpcRouter) detachDevice(uid string) {
|
|
r.lock.Lock()
|
|
defer r.lock.Unlock()
|
|
|
|
if r.callbacks[uid] != nil {
|
|
r.callbacks[uid] = make(map[rpc.ClassMethod][]RPCCallback)
|
|
}
|
|
}
|
|
|
|
func (r *rpcRouter) attach(uid string, classMethod rpc.ClassMethod, cb RPCCallback) {
|
|
r.lock.Lock()
|
|
defer r.lock.Unlock()
|
|
|
|
if uid == "*" {
|
|
uid = ""
|
|
}
|
|
|
|
if classMethod == "*" {
|
|
classMethod = ""
|
|
}
|
|
|
|
if r.callbacks == nil {
|
|
r.callbacks = make(map[string]map[rpc.ClassMethod][]RPCCallback)
|
|
}
|
|
if r.callbacks[uid] == nil {
|
|
r.callbacks[uid] = make(map[rpc.ClassMethod][]RPCCallback)
|
|
}
|
|
|
|
r.callbacks[uid][classMethod] = append(r.callbacks[uid][classMethod], cb)
|
|
}
|
|
|
|
func (r *rpcRouter) fireCallbacks(msg *rpc.Msg, classMethodCallbacks map[rpc.ClassMethod][]RPCCallback) []*rpc.Msg {
|
|
var msgs []*rpc.Msg
|
|
|
|
// wildcard
|
|
if wildCbList, ok := classMethodCallbacks[""]; ok {
|
|
for _, cb := range wildCbList {
|
|
m := cb(msg)
|
|
if m != nil {
|
|
msgs = append(msgs, m)
|
|
}
|
|
}
|
|
}
|
|
|
|
// specific
|
|
if cbList, ok := classMethodCallbacks[msg.ClassMethod]; ok {
|
|
for _, cb := range cbList {
|
|
m := cb(msg)
|
|
if m != nil {
|
|
msgs = append(msgs, m)
|
|
}
|
|
}
|
|
}
|
|
|
|
return msgs
|
|
}
|
|
|
|
func (r *rpcRouter) fire(msg *rpc.Msg) []*rpc.Msg {
|
|
var msgs []*rpc.Msg
|
|
|
|
r.lock.RLock()
|
|
defer r.lock.RUnlock()
|
|
|
|
if wlcDevs, ok := r.callbacks[""]; ok {
|
|
msgs = append(msgs, r.fireCallbacks(msg, wlcDevs)...)
|
|
}
|
|
|
|
if len(msg.DeviceUID) == 0 {
|
|
return msgs
|
|
}
|
|
|
|
if directDevs, ok := r.callbacks[msg.DeviceUID]; ok {
|
|
msgs = append(msgs, r.fireCallbacks(msg, directDevs)...)
|
|
}
|
|
|
|
return msgs
|
|
}
|