Просмотр исходного кода

add connect timeout in dial function

Joyit 1 месяц назад
Родитель
Сommit
66c446a12a
4 измененных файлов с 57 добавлено и 20 удалено
  1. 1 1
      config/config.go
  2. 1 1
      conn/tcp2/tcp2.go
  3. 7 3
      conn/ws2/ws2.go
  4. 48 15
      hub.go

+ 1 - 1
config/config.go

@@ -43,7 +43,7 @@ type Config struct {
 func NewConfig() *Config {
 	// 配置基础的数据
 	return &Config{
-		ConnectTimeout:       60 * 1000,
+		ConnectTimeout:       30 * 1000,
 		PingInterval:         61 * 1000,
 		WriteWait:            60 * 1000,
 		ReadWait:             30 * 1000,

+ 1 - 1
conn/tcp2/tcp2.go

@@ -122,7 +122,7 @@ func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFu
 }
 
 // 客户端,新建一个连接
-func Client(cf *config.Config, addr string, hash string) (conn.Connect, error) {
+func Dial(cf *config.Config, addr string, hash string) (conn.Connect, error) {
 	// 没有加密的情况
 	if hash == "" {
 		conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)

+ 7 - 3
conn/ws2/ws2.go

@@ -125,11 +125,13 @@ func Server(cf *config.Config, bind string, path string, hash string, fn conn.Se
 }
 
 // 客户端,新建一个连接
-func Client(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) {
+func Dial(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) {
 	u := url.URL{Scheme: scheme, Host: addr, Path: path}
 	// 没有加密的情况
 	if hash == "" {
-		conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
+		conn, _, err := (&websocket.Dialer{
+			HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
+		}).Dial(u.String(), nil)
 		if err != nil {
 			return nil, err
 		}
@@ -149,7 +151,9 @@ func Client(cf *config.Config, scheme string, addr string, path string, hash str
 	if !ok {
 		return nil, errors.New("Unsupported encryption method: " + encryptMethod)
 	}
-	conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
+	conn, _, err := (&websocket.Dialer{
+		HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
+	}).Dial(u.String(), nil)
 	if err != nil {
 		return nil, err
 	}

+ 48 - 15
hub.go

@@ -1,6 +1,7 @@
 package tinymq
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"log"
@@ -561,23 +562,55 @@ func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
 	if err != nil {
 		return err
 	}
+
 	var conn conn.Connect
 	var runProto string
 	addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
-	if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
-		runProto = ws2.PROTO
-		conn, err = ws2.Client(h.cf, host.Proto, addr, host.Path, host.Hash)
-	} else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
-		runProto = tcp2.PROTO
-		conn, err = tcp2.Client(h.cf, addr, host.Hash)
-	} else {
-		return fmt.Errorf("not correct protocol and version found in: %+v", host)
-	}
-	if err != nil {
-		log.Println("[Client ERROR]", host.Proto, err)
-		host.Errors++
-		host.Updated = time.Now()
-		return err
+
+	// 添加定时器
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
+	defer cancel()
+	taskCh := make(chan bool)
+	done := false
+
+	go func() {
+		if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
+			runProto = ws2.PROTO
+			conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
+		} else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
+			runProto = tcp2.PROTO
+			conn, err = tcp2.Dial(h.cf, addr, host.Hash)
+
+		} else {
+			err = fmt.Errorf("not correct protocol and version found in: %+v", host)
+		}
+		if done {
+			if err != nil {
+				log.Println("[Dial ERROR]", err)
+			}
+			if conn != nil {
+				conn.Close()
+			}
+		} else {
+			taskCh <- err == nil
+		}
+	}()
+
+	select {
+	case ok := <-taskCh:
+		cancel()
+		if !ok || err != nil || conn == nil {
+			log.Println("[Client ERROR]", host.Proto, err)
+			host.Errors++
+			host.Updated = time.Now()
+			if err == nil {
+				err = errors.New("unknown error")
+			}
+			return err
+		}
+	case <-ctx.Done():
+		done = true
+		return errors.New("timeout")
 	}
 	// 发送验证信息
 	if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
@@ -628,7 +661,7 @@ func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
 	host.Updated = time.Now()
 
 	// 将连接加入现有连接中
-	done := false
+	done = false
 	h.connects.Range(func(key, _ any) bool {
 		line := key.(*Line)
 		if line.channel == channel {