254 lines
4.7 KiB
Go
254 lines
4.7 KiB
Go
// +build cgo
|
|
|
|
package dinet
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
zmq "github.com/pebbe/zmq4"
|
|
|
|
"src.dualinventive.com/go/dinet/rpc"
|
|
)
|
|
|
|
// ZmqTransportTimeoutInfinite disables the receive wait timeout: Recv passes
// the timeout to the zmq poller, and a negative duration makes it wait until
// a message arrives. Intended as the default value.
//
// NOTE(review): the zero value of ZmqTransport.timeout is 0, not this
// constant; confirm that a constructor elsewhere assigns it.
const ZmqTransportTimeoutInfinite time.Duration = -time.Second
|
|
|
|
// ZmqTransport is an encoder to encode/decode rpc messages over a zmq tcp socket
|
|
type ZmqTransport struct {
|
|
mu sync.RWMutex
|
|
|
|
host string
|
|
ctx *zmq.Context
|
|
sock *zmq.Socket
|
|
timeout time.Duration
|
|
}
|
|
|
|
// Connect connects to the specified host
|
|
func (zmqe *ZmqTransport) Connect(host string) error {
|
|
zmqe.mu.Lock()
|
|
defer zmqe.mu.Unlock()
|
|
return zmqe.connect(host)
|
|
}
|
|
|
|
func (zmqe *ZmqTransport) connect(host string) error {
|
|
if host != "" {
|
|
zmqe.host = host
|
|
}
|
|
|
|
hostURL, err := url.Parse(zmqe.host)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if zmqe.ctx == nil {
|
|
ctx, err := zmq.NewContext()
|
|
if err != nil {
|
|
return errors.New("zmq.NewContext failed")
|
|
}
|
|
zmqe.ctx = ctx
|
|
}
|
|
|
|
return zmqe.attachZmqSocket(hostURL)
|
|
}
|
|
|
|
func (zmqe *ZmqTransport) terminate() {
|
|
if zmqe.ctx != nil {
|
|
ret := zmqe.ctx.Term()
|
|
_ = ret
|
|
zmqe.ctx = nil
|
|
}
|
|
}
|
|
func (zmqe *ZmqTransport) close() error {
|
|
if zmqe.sock != nil {
|
|
err := zmqe.sock.Close()
|
|
_ = err
|
|
zmqe.sock = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Close disconnects the socket
|
|
func (zmqe *ZmqTransport) Close() error {
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
zmqe.terminate()
|
|
wg.Done()
|
|
}()
|
|
|
|
zmqe.mu.Lock()
|
|
err := zmqe.close()
|
|
zmqe.mu.Unlock()
|
|
wg.Wait()
|
|
return err
|
|
}
|
|
|
|
// Reconnect closes the socket and reconnects again
|
|
func (zmqe *ZmqTransport) Reconnect() error {
|
|
zmqe.mu.Lock()
|
|
defer zmqe.mu.Unlock()
|
|
|
|
if zmqe.sock == nil {
|
|
return ErrClosed
|
|
}
|
|
|
|
if err := zmqe.close(); err != nil {
|
|
return err
|
|
}
|
|
if err := zmqe.connect(""); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetTimeout sets the time to wait for incoming messages
|
|
func (zmqe *ZmqTransport) SetTimeout(timeout time.Duration) error {
|
|
zmqe.mu.Lock()
|
|
zmqe.timeout = timeout
|
|
zmqe.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// Send encodes a rpc message and sends it over the ZMQ socket
|
|
func (zmqe *ZmqTransport) Send(m *rpc.Msg) error {
|
|
zmqe.mu.RLock()
|
|
defer zmqe.mu.RUnlock()
|
|
|
|
if zmqe.sock == nil {
|
|
return ErrClosed
|
|
}
|
|
|
|
data, err := json.Marshal(m)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = zmqe.sock.SendBytes(data, 0)
|
|
if err == zmq.ETERM {
|
|
return ErrClosed
|
|
}
|
|
|
|
// When an EFSM occurs "Operation cannot be accomplished in current state" we reconnect and retry
|
|
if err == zmq.EFSM {
|
|
zmqe.mu.RUnlock()
|
|
err = zmqe.Reconnect()
|
|
zmqe.mu.RLock()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = zmqe.sock.SendBytes(data, 0)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Recv waits until it gets a rpc message from the tcp socket and decodes the message
|
|
func (zmqe *ZmqTransport) Recv() (*rpc.Msg, error) {
|
|
zmqe.mu.RLock()
|
|
defer zmqe.mu.RUnlock()
|
|
|
|
if zmqe.sock == nil {
|
|
return nil, ErrClosed
|
|
}
|
|
|
|
// Create a poller so we can wait for incomming messages
|
|
poller := zmq.NewPoller()
|
|
// Suppress linter, because zmq.POLLIN is a cgo variable which gotype doesn't support
|
|
//nolint: gotype
|
|
poller.Add(zmqe.sock, zmq.POLLIN)
|
|
|
|
sockets, err := poller.Poll(zmqe.timeout)
|
|
if err != nil {
|
|
if err == zmq.ETERM {
|
|
err = ErrClosed
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if len(sockets) == 0 {
|
|
return nil, ErrTimeout
|
|
}
|
|
|
|
var msgBytes []byte
|
|
more := true
|
|
for more {
|
|
msgBytes, err = zmqe.sock.RecvBytes(0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
more, err = zmqe.sock.GetRcvmore()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
msg := &rpc.Msg{}
|
|
if err := json.Unmarshal(msgBytes, msg); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := msg.Valid(); err != nil {
|
|
return nil, err
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
func (zmqe *ZmqTransport) getSocketType(t string) (zmq.Type, error) {
|
|
switch t {
|
|
case "req":
|
|
return zmq.REQ, nil
|
|
case "rep":
|
|
return zmq.REP, nil
|
|
case "sub":
|
|
return zmq.SUB, nil
|
|
case "pub":
|
|
return zmq.PUB, nil
|
|
}
|
|
|
|
return zmq.Type(0), ErrType
|
|
}
|
|
|
|
// attachZmqSocket attaches and configures the socket based on the url
|
|
func (zmqe *ZmqTransport) attachZmqSocket(url *url.URL) error {
|
|
var err error
|
|
|
|
hostQuery := url.Query()
|
|
zmqType := hostQuery.Get("type")
|
|
|
|
t, err := zmqe.getSocketType(zmqType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
zmqe.sock, err = zmqe.ctx.NewSocket(t)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if errlin := zmqe.sock.SetLinger(0); errlin != nil {
|
|
return errlin
|
|
}
|
|
// All query values are parsed. Remove the query
|
|
url.RawQuery = ""
|
|
|
|
if hostQuery.Get("bind") == "true" {
|
|
if err = zmqe.sock.Bind(url.String()); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err = zmqe.sock.Connect(url.String()); err != nil {
|
|
return ErrDisconnected
|
|
}
|
|
}
|
|
|
|
if zmqType == "sub" {
|
|
if errsub := zmqe.sock.SetSubscribe(""); errsub != nil {
|
|
return errsub
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|