src.dualinventive.com/go/dinet/transport_zmq_linux.go

254 lines
4.7 KiB
Go

// +build cgo
package dinet
import (
"encoding/json"
"errors"
"net/url"
"sync"
"time"
zmq "github.com/pebbe/zmq4"
"src.dualinventive.com/go/dinet/rpc"
)
// ZmqTransportTimeoutInfinite is used for infinite send and receive wait timeouts (default value)
const ZmqTransportTimeoutInfinite = -1 * time.Second
// ZmqTransport is an encoder to encode/decode rpc messages over a zmq tcp socket.
//
// Messages are sent and received as JSON-encoded rpc.Msg values. The
// endpoint, socket type and bind/connect mode come from the host URL given to
// Connect (see attachZmqSocket).
//
// Locking: Send and Recv hold mu for reading; Connect, Reconnect and
// SetTimeout hold it for writing. Close holds mu only while closing the socket,
// and terminates the context from a separate goroutine without holding mu.
type ZmqTransport struct {
// mu guards the fields below (ctx is the exception inside terminate, which
// Close runs without the lock).
mu sync.RWMutex
// host is the last non-empty host passed to connect; Reconnect reuses it.
host string
// ctx is created lazily by connect and terminated/cleared by terminate.
ctx *zmq.Context
// sock is nil when the transport is closed; Send/Recv then return ErrClosed.
sock *zmq.Socket
// timeout is handed to the poller in Recv; see ZmqTransportTimeoutInfinite.
timeout time.Duration
}
// Connect creates (or recreates) the socket for host and attaches it to the
// transport. The host is a URL whose query parameters select the socket type
// and whether to bind or connect. It holds the write lock for the whole
// operation.
func (zmqe *ZmqTransport) Connect(host string) error {
	lock := &zmqe.mu
	lock.Lock()
	defer lock.Unlock()

	err := zmqe.connect(host)
	return err
}
// connect parses the host URL, lazily creates the zmq context and attaches a
// freshly configured socket. The caller must hold zmqe.mu for writing.
//
// A non-empty host replaces the stored one; an empty host reuses the stored
// host (Reconnect relies on this). The stored host is only updated once the new
// value parses, so a bad argument does not clobber a previously working address.
// Any socket left over from an earlier connect is closed before the new one is
// attached, so repeated Connect calls do not leak sockets.
func (zmqe *ZmqTransport) connect(host string) error {
	if host == "" {
		host = zmqe.host
	}
	hostURL, err := url.Parse(host)
	if err != nil {
		return err
	}
	zmqe.host = host

	if zmqe.ctx == nil {
		ctx, err := zmq.NewContext()
		if err != nil {
			// Keep the original message prefix but preserve the cause.
			return errors.New("zmq.NewContext failed: " + err.Error())
		}
		zmqe.ctx = ctx
	}

	// attachZmqSocket overwrites zmqe.sock; release the previous socket first.
	// close is best effort and currently always returns nil.
	_ = zmqe.close()

	return zmqe.attachZmqSocket(hostURL)
}
// terminate shuts down the zmq context, if any, and forgets it so the next
// connect creates a new one. Close calls it from a separate goroutine without
// holding zmqe.mu.
func (zmqe *ZmqTransport) terminate() {
	ctx := zmqe.ctx
	if ctx == nil {
		return
	}
	// Shutdown is best effort: the transport is being torn down and there is
	// nothing useful to do with a termination error here.
	_ = ctx.Term()
	zmqe.ctx = nil
}
// close releases the current socket, if any, and marks the transport as
// closed (zmqe.sock == nil). Errors from closing the socket are intentionally
// ignored, so this always returns nil. The caller must hold zmqe.mu for writing.
func (zmqe *ZmqTransport) close() error {
	sock := zmqe.sock
	if sock == nil {
		return nil
	}
	_ = sock.Close()
	zmqe.sock = nil
	return nil
}
// Close disconnects the socket and terminates the zmq context.
//
// The context is terminated in a separate goroutine *before* the lock is
// taken. A Send or Recv that is blocked while holding the read lock gets ETERM
// from the terminating context, which the code maps to ErrClosed, and releases
// the lock. Close can then acquire the lock and close the socket. libzmq
// documents that context termination waits for all sockets to be closed, so
// Close waits for the goroutine only after the socket is gone.
func (zmqe *ZmqTransport) Close() error {
	termDone := make(chan struct{})
	go func() {
		zmqe.terminate()
		close(termDone)
	}()

	zmqe.mu.Lock()
	closeErr := zmqe.close()
	zmqe.mu.Unlock()

	<-termDone
	return closeErr
}
// Reconnect closes the current socket and connects again to the host given in
// the last Connect call. It returns ErrClosed when the transport has no socket
// (never connected, or already closed).
func (zmqe *ZmqTransport) Reconnect() error {
	zmqe.mu.Lock()
	defer zmqe.mu.Unlock()

	if zmqe.sock == nil {
		return ErrClosed
	}

	err := zmqe.close()
	if err == nil {
		err = zmqe.connect("")
	}
	return err
}
// SetTimeout sets how long Recv waits for an incoming message before returning
// ErrTimeout. Pass ZmqTransportTimeoutInfinite to wait without a deadline. It
// always returns nil.
func (zmqe *ZmqTransport) SetTimeout(timeout time.Duration) error {
	zmqe.mu.Lock()
	defer zmqe.mu.Unlock()

	zmqe.timeout = timeout
	return nil
}
// Send encodes m as JSON and sends it as a single frame over the ZMQ socket.
//
// It returns ErrClosed when the transport is closed or its context has been
// terminated (ETERM). If the socket reports EFSM ("operation cannot be
// accomplished in current state"), the transport is reconnected and the send
// is retried once.
func (zmqe *ZmqTransport) Send(m *rpc.Msg) error {
	zmqe.mu.RLock()
	// Paired with whichever read lock is held on return. The EFSM path below
	// drops and re-acquires it.
	defer zmqe.mu.RUnlock()

	if zmqe.sock == nil {
		return ErrClosed
	}
	data, err := json.Marshal(m)
	if err != nil {
		return err
	}

	_, err = zmqe.sock.SendBytes(data, 0)
	if err == zmq.EFSM {
		// Reconnect needs the write lock, so the read lock is released while
		// it runs.
		// NOTE(review): this waits for every other reader, e.g. a Recv blocked
		// with an infinite timeout, to release mu first.
		zmqe.mu.RUnlock()
		err = zmqe.Reconnect()
		zmqe.mu.RLock()
		if err != nil {
			return err
		}
		// While the lock was released another goroutine may have closed the
		// transport. Re-check instead of dereferencing a nil socket.
		if zmqe.sock == nil {
			return ErrClosed
		}
		_, err = zmqe.sock.SendBytes(data, 0)
	}
	if err == zmq.ETERM {
		return ErrClosed
	}
	return err
}
// Recv waits for a message on the ZMQ socket, decodes it from JSON and
// validates it.
//
// It waits at most the duration configured with SetTimeout and returns
// ErrTimeout when nothing arrives in time. It returns ErrClosed when the
// transport is closed or its context has been terminated (ETERM), whether that
// happens while polling or while reading. For multipart messages all frames
// are consumed, but only the last frame is decoded.
func (zmqe *ZmqTransport) Recv() (*rpc.Msg, error) {
	zmqe.mu.RLock()
	defer zmqe.mu.RUnlock()
	if zmqe.sock == nil {
		return nil, ErrClosed
	}

	// Create a poller so we can wait for incoming messages with a timeout.
	poller := zmq.NewPoller()
	// Suppress linter, because zmq.POLLIN is a cgo variable which gotype doesn't support
	//nolint: gotype
	poller.Add(zmqe.sock, zmq.POLLIN)
	sockets, err := poller.Poll(zmqe.timeout)
	if err != nil {
		if err == zmq.ETERM {
			err = ErrClosed
		}
		return nil, err
	}
	if len(sockets) == 0 {
		return nil, ErrTimeout
	}

	// Drain every frame of the message. Earlier frames (e.g. envelopes) are
	// discarded and only the final frame is kept as the payload.
	var msgBytes []byte
	for more := true; more; {
		msgBytes, err = zmqe.sock.RecvBytes(0)
		if err != nil {
			if err == zmq.ETERM {
				return nil, ErrClosed
			}
			return nil, err
		}
		more, err = zmqe.sock.GetRcvmore()
		if err != nil {
			if err == zmq.ETERM {
				return nil, ErrClosed
			}
			return nil, err
		}
	}

	msg := &rpc.Msg{}
	if err := json.Unmarshal(msgBytes, msg); err != nil {
		return nil, err
	}
	if err := msg.Valid(); err != nil {
		return nil, err
	}
	return msg, nil
}
// getSocketType maps the "type" query value of the host URL to a zmq socket
// type. Only "req", "rep", "sub" and "pub" are supported. Anything else yields
// ErrType.
func (zmqe *ZmqTransport) getSocketType(t string) (zmq.Type, error) {
	var sockType zmq.Type
	switch t {
	case "req":
		sockType = zmq.REQ
	case "rep":
		sockType = zmq.REP
	case "sub":
		sockType = zmq.SUB
	case "pub":
		sockType = zmq.PUB
	default:
		return zmq.Type(0), ErrType
	}
	return sockType, nil
}
// attachZmqSocket creates a socket as described by u, configures it and
// stores it in zmqe.sock. zmqe.ctx must already be set.
//
// Recognised query parameters of u:
//   - type: socket type ("req", "rep", "sub", "pub"; see getSocketType)
//   - bind: "true" binds to the endpoint, anything else connects to it
//
// The query string is stripped before the endpoint is handed to zmq. The
// caller's URL is left unmodified. Connect failures are reported as
// ErrDisconnected. Subscriber sockets subscribe to all topics.
//
// zmqe.sock is assigned as soon as the socket is created, so a failure in a
// later step leaves that socket in place. Reconnect treats a nil socket as
// "closed", so this behavior is kept.
func (zmqe *ZmqTransport) attachZmqSocket(u *url.URL) error {
	query := u.Query()
	zmqType := query.Get("type")
	t, err := zmqe.getSocketType(zmqType)
	if err != nil {
		return err
	}

	zmqe.sock, err = zmqe.ctx.NewSocket(t)
	if err != nil {
		return err
	}
	if err = zmqe.sock.SetLinger(0); err != nil {
		return err
	}

	// All query values are consumed above. Work on a copy so the caller's URL
	// is not mutated when the query is removed.
	endpoint := *u
	endpoint.RawQuery = ""
	addr := endpoint.String()

	if query.Get("bind") == "true" {
		if err = zmqe.sock.Bind(addr); err != nil {
			return err
		}
	} else if err = zmqe.sock.Connect(addr); err != nil {
		return ErrDisconnected
	}

	if zmqType == "sub" {
		if err = zmqe.sock.SetSubscribe(""); err != nil {
			return err
		}
	}
	return nil
}