|
@@ -1,6 +1,7 @@
|
|
|
package tinymq
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"log"
|
|
@@ -561,23 +562,55 @@ func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+
|
|
|
var conn conn.Connect
|
|
|
var runProto string
|
|
|
addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
|
|
|
- if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
|
|
|
- runProto = ws2.PROTO
|
|
|
- conn, err = ws2.Client(h.cf, host.Proto, addr, host.Path, host.Hash)
|
|
|
- } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
|
|
|
- runProto = tcp2.PROTO
|
|
|
- conn, err = tcp2.Client(h.cf, addr, host.Hash)
|
|
|
- } else {
|
|
|
- return fmt.Errorf("not correct protocol and version found in: %+v", host)
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- log.Println("[Client ERROR]", host.Proto, err)
|
|
|
- host.Errors++
|
|
|
- host.Updated = time.Now()
|
|
|
- return err
|
|
|
+
|
|
|
+ // 添加定时器
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
|
|
|
+ defer cancel()
|
|
|
+ taskCh := make(chan bool)
|
|
|
+ done := false
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
|
|
|
+ runProto = ws2.PROTO
|
|
|
+ conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
|
|
|
+ } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
|
|
|
+ runProto = tcp2.PROTO
|
|
|
+ conn, err = tcp2.Dial(h.cf, addr, host.Hash)
|
|
|
+
|
|
|
+ } else {
|
|
|
+ err = fmt.Errorf("not correct protocol and version found in: %+v", host)
|
|
|
+ }
|
|
|
+ if done {
|
|
|
+ if err != nil {
|
|
|
+ log.Println("[Dial ERROR]", err)
|
|
|
+ }
|
|
|
+ if conn != nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ taskCh <- err == nil
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case ok := <-taskCh:
|
|
|
+ cancel()
|
|
|
+ if !ok || err != nil || conn == nil {
|
|
|
+ log.Println("[Client ERROR]", host.Proto, err)
|
|
|
+ host.Errors++
|
|
|
+ host.Updated = time.Now()
|
|
|
+ if err == nil {
|
|
|
+ err = errors.New("unknown error")
|
|
|
+ }
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ case <-ctx.Done():
|
|
|
+ done = true
|
|
|
+ return errors.New("timeout")
|
|
|
}
|
|
|
// 发送验证信息
|
|
|
if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
|
|
@@ -628,7 +661,7 @@ func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
|
|
|
host.Updated = time.Now()
|
|
|
|
|
|
// 将连接加入现有连接中
|
|
|
- done := false
|
|
|
+ done = false
|
|
|
h.connects.Range(func(key, _ any) bool {
|
|
|
line := key.(*Line)
|
|
|
if line.channel == channel {
|