src.dualinventive.com/go/dinet/router.go

204 lines
5.0 KiB
Go

package dinet
import (
"sync"
"src.dualinventive.com/go/dinet/rpc"
"src.dualinventive.com/go/lib/dilog"
)
// RPCCallback handles an RPC message delivered by a Router.
//
// A non-nil return value is sent back over the router's connection after its
// DeviceUID, UserID and ProjectID have been copied from msg. Returning nil
// sends nothing for this callback.
type RPCCallback func(msg *rpc.Msg) *rpc.Msg
// Router reads RPC messages from a connection and dispatches them to the
// callbacks registered with Subscribe, selected by message type, device UID
// and class method. Messages returned by callbacks are sent back over the
// same connection.
//
// Router holds sync.RWMutex values (inside rpcRouter) and must not be copied;
// use it through the pointer returned by NewRouter.
type Router struct {
	logger dilog.Logger
	con    ReadWriter

	// One callback table per message type. routerAny is consulted for
	// publish, request and reply messages alike.
	routerAny, routerPub, routerRep, routerReq rpcRouter
}
// NewRouter creates a Router that receives from and sends to con and logs
// through logger. Call ListenAndServe to start dispatching messages.
//
// The returned error is currently always nil.
func NewRouter(logger dilog.Logger, con ReadWriter) (*Router, error) {
	return &Router{
		con:    con,
		logger: logger,
	}, nil
}
// Close closes the underlying connection when it implements Closer and
// returns that error. For a connection that cannot be closed it does nothing
// and returns nil.
func (c *Router) Close() error {
	closer, ok := c.con.(Closer)
	if !ok {
		return nil
	}
	return closer.Close()
}
// ListenAndServe receives messages from the connection in an endless loop and
// hands every RPC message to the matching callbacks. It returns only on error:
//
//   - an *rpc.InvalidMsgError from Recv is logged as a warning; when the
//     offending message is a request it is answered with an rpc.EProto error
//     reply, and the loop continues. A failure to send that reply is logged
//     and returned.
//   - any other Recv error is returned unchanged.
func (c *Router) ListenAndServe() error {
	for {
		msg, err := c.con.Recv()
		if err == nil {
			// Recv may yield no message and no error when it received
			// something that is not an RPC message; nothing to route then.
			if msg != nil {
				c.handlePlainMessage(msg)
			}
			continue
		}

		msgError, ok := err.(*rpc.InvalidMsgError)
		if !ok {
			return err
		}
		c.logger.WithError(msgError).WithField("msg", msgError.Msg).Warning("Invalid RPC message")
		if msgError.Msg.Type != rpc.MsgTypeRequest {
			continue
		}

		// Requests expect an answer, so tell the peer its message was invalid.
		rep := msgError.Msg.CreateReply()
		rep.SetError(rpc.EProto)
		if err := c.con.Send(rep); err != nil {
			c.logger.WithError(err).Error("Router.ListenAndServe send failed")
			return err
		}
	}
}
// handlePlainMessage fires the callbacks matching rpcmsg, first those of the
// table for its message type and then those of routerAny. It sends every
// non-nil message they return over the connection, with DeviceUID, UserID and
// ProjectID copied from rpcmsg.
//
// A request for which no callback produced a reply is answered with an
// rpc.EOpnotsupp error reply. Messages of any other type are ignored. Send
// failures are only logged, at debug level.
func (c *Router) handlePlainMessage(rpcmsg *rpc.Msg) {
	var (
		table     *rpcRouter
		isRequest bool
	)
	switch rpcmsg.Type {
	case rpc.MsgTypePublish:
		table = &c.routerPub
	case rpc.MsgTypeRequest:
		table, isRequest = &c.routerReq, true
	case rpc.MsgTypeReply:
		table = &c.routerRep
	default:
		return
	}

	replies := table.fire(rpcmsg)
	replies = append(replies, c.routerAny.fire(rpcmsg)...)
	if isRequest && len(replies) == 0 {
		c.logger.WithField("classmethod", rpcmsg.ClassMethod).Debug("No reply for request")
		replies = append(replies, rpcmsg.CreateReply().SetError(rpc.EOpnotsupp))
	}

	for _, reply := range replies {
		reply.DeviceUID, reply.UserID, reply.ProjectID = rpcmsg.DeviceUID, rpcmsg.UserID, rpcmsg.ProjectID
		if err := c.con.Send(reply); err != nil {
			c.logger.WithError(err).Debug("error while sending reply")
		}
	}
}
// Subscribe registers cb for messages of type t coming from the device uid
// with the given class method. Passing "*" as uid or classMethod matches any
// device or any class method. rpc.MsgTypeAny subscribes to publish, request
// and reply messages alike.
//
// Subscribe panics if t is not rpc.MsgTypePublish, rpc.MsgTypeRequest,
// rpc.MsgTypeReply or rpc.MsgTypeAny.
func (c *Router) Subscribe(uid string, classMethod rpc.ClassMethod, t rpc.MsgType, cb RPCCallback) {
	var target *rpcRouter
	switch t {
	case rpc.MsgTypePublish:
		target = &c.routerPub
	case rpc.MsgTypeRequest:
		target = &c.routerReq
	case rpc.MsgTypeReply:
		target = &c.routerRep
	case rpc.MsgTypeAny:
		target = &c.routerAny
	default:
		panic("unknown message type")
	}
	target.attach(uid, classMethod, cb)
}
// UnsubscribeDevice removes every callback registered for the device uid,
// whatever its message type or class method.
func (c *Router) UnsubscribeDevice(uid string) {
	for _, table := range []*rpcRouter{&c.routerReq, &c.routerRep, &c.routerPub, &c.routerAny} {
		table.detachDevice(uid)
	}
}
// Send writes msg directly to the router's connection and returns the
// connection's error.
func (c *Router) Send(msg *rpc.Msg) error {
	con := c.con
	return con.Send(msg)
}
// rpcRouter is a callback table; Router keeps one per message type plus one
// for any type. callbacks is indexed first by device UID and then by class
// method, and the empty string is the wildcard key at both levels (attach
// stores "*" as ""). The zero value is usable: attach allocates the maps
// lazily.
type rpcRouter struct {
	// lock guards callbacks.
	lock      sync.RWMutex
	callbacks map[string]map[rpc.ClassMethod][]RPCCallback
}
// detachDevice removes every callback registered for the device uid. As in
// attach, a uid of "*" refers to the wildcard entry (stored under ""), so
// detaching "*" removes the subscriptions that were made with "*".
// Detaching a uid that has no callbacks is a no-op.
func (r *rpcRouter) detachDevice(uid string) {
	if uid == "*" {
		uid = ""
	}

	r.lock.Lock()
	defer r.lock.Unlock()

	// Drop the entry entirely instead of replacing it with an empty map;
	// otherwise every device ever subscribed would keep a key in callbacks
	// forever. delete on a nil map is a no-op.
	delete(r.callbacks, uid)
}
// attach registers cb for the device uid and class method classMethod. A "*"
// for either argument is stored under the empty-string wildcard key. The
// callback maps are allocated on first use, so a zero rpcRouter is usable.
func (r *rpcRouter) attach(uid string, classMethod rpc.ClassMethod, cb RPCCallback) {
	const wildcard = "*"
	if uid == wildcard {
		uid = ""
	}
	if classMethod == wildcard {
		classMethod = ""
	}

	r.lock.Lock()
	defer r.lock.Unlock()

	if r.callbacks == nil {
		r.callbacks = map[string]map[rpc.ClassMethod][]RPCCallback{}
	}
	byMethod := r.callbacks[uid]
	if byMethod == nil {
		byMethod = map[rpc.ClassMethod][]RPCCallback{}
		r.callbacks[uid] = byMethod
	}
	byMethod[classMethod] = append(byMethod[classMethod], cb)
}
// fireCallbacks invokes the callbacks in classMethodCallbacks that match
// msg.ClassMethod and returns the non-nil messages they produce. Callbacks
// registered for the wildcard class method ("") run first, then those
// registered for msg.ClassMethod itself. A nil map yields no messages.
func (r *rpcRouter) fireCallbacks(msg *rpc.Msg, classMethodCallbacks map[rpc.ClassMethod][]RPCCallback) []*rpc.Msg {
	keys := []rpc.ClassMethod{""}
	// A message without a class method would otherwise select the wildcard
	// list a second time and invoke every wildcard callback twice. This is
	// the same reasoning as the empty DeviceUID check when firing by device.
	if msg.ClassMethod != "" {
		keys = append(keys, msg.ClassMethod)
	}

	var msgs []*rpc.Msg
	for _, key := range keys {
		for _, cb := range classMethodCallbacks[key] {
			if m := cb(msg); m != nil {
				msgs = append(msgs, m)
			}
		}
	}
	return msgs
}
// fire invokes every callback matching msg and returns the non-nil messages
// they produce. Callbacks subscribed to the wildcard device ("") run first,
// then those subscribed to msg.DeviceUID.
//
// The matching callbacks are copied while r.lock is read-locked and invoked
// only after it has been released. A callback may therefore call attach or
// detachDevice on this table (e.g. through Router.Subscribe or
// Router.UnsubscribeDevice) without deadlocking on r.lock. Taking the write
// lock while the same goroutine holds the read lock would block forever.
func (r *rpcRouter) fire(msg *rpc.Msg) []*rpc.Msg {
	var msgs []*rpc.Msg
	for _, classMethodCallbacks := range r.matchingCallbacks(msg) {
		msgs = append(msgs, r.fireCallbacks(msg, classMethodCallbacks)...)
	}
	return msgs
}

// matchingCallbacks returns, for the wildcard device and then for
// msg.DeviceUID, a private copy of the callback lists registered under the
// wildcard class method and under msg.ClassMethod. Devices without an entry
// are skipped.
func (r *rpcRouter) matchingCallbacks(msg *rpc.Msg) []map[rpc.ClassMethod][]RPCCallback {
	r.lock.RLock()
	defer r.lock.RUnlock()

	devices := []string{""}
	// An empty DeviceUID would select the wildcard entry a second time.
	if len(msg.DeviceUID) != 0 {
		devices = append(devices, msg.DeviceUID)
	}

	var snapshots []map[rpc.ClassMethod][]RPCCallback
	for _, uid := range devices {
		classMethodCallbacks, ok := r.callbacks[uid]
		if !ok {
			continue
		}
		snapshot := make(map[rpc.ClassMethod][]RPCCallback, 2)
		for _, cm := range []rpc.ClassMethod{"", msg.ClassMethod} {
			if cbs, ok := classMethodCallbacks[cm]; ok {
				// Copy so later appends by attach cannot alias this view.
				snapshot[cm] = append(cbs[:0:0], cbs...)
			}
		}
		snapshots = append(snapshots, snapshot)
	}
	return snapshots
}