src.dualinventive.com/go/websocketserver/internal/wsconn/testutil/client.go

109 lines
2.5 KiB
Go

package testutil
import (
"encoding/json"
"fmt"
"net"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"src.dualinventive.com/go/dinet/rpc"
)
// Client is a generic websocket client.
// It holds one websocket connection together with the URI it was dialed with.
type Client struct {
// uri is the address passed to NewClient when this client was created.
uri string
// c is the websocket connection that every Client method operates on.
c *websocket.Conn
}
// NewClient returns a new websocket Client connected to URI.
//
// Before dialing, it waits up to 10 seconds for the TCP port named in URI to
// accept connections. When URI carries no explicit port, the scheme's default
// port is probed instead (80 for "ws", 443 for "wss"); for any other scheme
// the empty port is passed through unchanged. Only the port is handed to
// waitForPort, not the host.
//
// It returns an error when URI cannot be parsed, when the port does not come
// up in time, or when dialing the websocket fails.
func NewClient(URI string) (*Client, error) {
	u, err := url.Parse(URI)
	if err != nil {
		return nil, err
	}

	// url.URL.Port returns "" when the URI has no explicit port. Fall back to
	// the websocket scheme default so the readiness probe does not spend the
	// whole timeout dialing an empty port.
	port := u.Port()
	if port == "" {
		switch u.Scheme {
		case "ws":
			port = "80"
		case "wss":
			port = "443"
		}
	}
	if ok := waitForPort(port, 10*time.Second); !ok {
		return nil, fmt.Errorf("unable to connect to %s", URI)
	}

	c, _, err := websocket.DefaultDialer.Dial(URI, nil)
	if err != nil {
		return nil, err
	}
	return &Client{uri: URI, c: c}, nil
}
// Write sends b as one websocket text message and returns the error, if any,
// reported by the underlying connection.
func (t *Client) Write(b []byte) error {
	err := t.c.WriteMessage(websocket.TextMessage, b)
	return err
}
// WriteString sends s as one websocket text message.
// It converts s to bytes and delegates to Write.
func (t *Client) WriteString(s string) error {
	payload := []byte(s)
	return t.Write(payload)
}
// Read blocks until a websocket.TextMessage is received and returns its
// payload. Messages of any other type are dropped. The first read error ends
// the wait and is returned with a nil payload.
func (t *Client) Read() ([]byte, error) {
	for {
		kind, payload, err := t.c.ReadMessage()
		switch {
		case err != nil:
			return nil, err
		case kind == websocket.TextMessage:
			return payload, nil
		}
		// Anything that is not a text message is skipped; keep reading.
	}
}
// ReadRPCMsg blocks until the next text message arrives (other websocket
// messages are silently dropped by Read) and decodes it with getRPCMsg.
// It returns the decoded message and its raw result as a string, or the read
// or decode error.
func (t *Client) ReadRPCMsg() (msg *rpc.Msg, result string, err error) {
	raw, readErr := t.Read()
	if readErr != nil {
		return nil, "", readErr
	}
	return getRPCMsg(raw)
}
// Close closes the underlying websocket connection.
// Any blocked Read and Write operations will be unblocked and return errors.
func (t *Client) Close() error {
	err := t.c.Close()
	return err
}
// getRPCMsg decodes the JSON in d into an *rpc.Msg. The message's Result is
// pre-set to a json.RawMessage, and that raw value is also returned as a
// string. On a decode error it returns nil, "", and the error.
func getRPCMsg(d []byte) (msg *rpc.Msg, result string, err error) {
	raw := json.RawMessage{}
	decoded := &rpc.Msg{Result: &raw}
	if unmarshalErr := json.Unmarshal(d, decoded); unmarshalErr != nil {
		return nil, "", unmarshalErr
	}
	return decoded, string(raw), nil
}
// waitForPort waits for a TCP port to become reachable within the timeout
// duration. It repeatedly dials ":"+port (no host is given) and returns true
// as soon as one connection attempt succeeds; the probe connection is closed
// right away and a failing close is only logged. It returns false when no
// attempt succeeds before the timeout elapses, including when timeout is zero
// or negative.
func waitForPort(port string, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return false
		}

		// Never let a single attempt outlive the overall deadline. remaining is
		// strictly positive here, so the dial timeout is never 0 (which
		// DialTimeout would treat as "no timeout").
		dialTimeout := time.Second
		if remaining < dialTimeout {
			dialTimeout = remaining
		}

		conn, err := net.DialTimeout("tcp", ":"+port, dialTimeout)
		if err == nil {
			if err := conn.Close(); err != nil {
				logrus.Errorf("close failed: %v", err)
			}
			return true
		}

		// Short pause between attempts so a refused connection does not spin.
		time.Sleep(10 * time.Millisecond)
	}
}