Browse Source

add filter to channel function

Joyit 1 month ago
parent
commit
65848344d0
6 changed files with 46 additions and 21 deletions
  1. 8 5
      examples/client-tcp2.go
  2. 0 1
      examples/client-ws2.go
  3. 0 1
      examples/server.go
  4. 34 13
      hub.go
  5. 3 0
      type.go
  6. 1 1
      type_test.go

+ 8 - 5
examples/client-tcp2.go

@@ -41,9 +41,12 @@ func main() {
 		func(conn *tinymq.Line) {
 			log.Println("connect state", conn.Channel(), conn.State(), time.Since(conn.Updated()))
 		},
-		nil,
 	)
 
+	hub.SetFilterToChannelFunc(func(filter tinymq.FilterFunc) (channel string) {
+		return remoteChannel
+	})
+
 	// 订阅频道
 	hub.Subscribe(remoteFilter, "hello", func(request *tinymq.RequestData) (state uint8, result any) {
 		log.Println("[client RECV]<-", string(request.Data))
@@ -56,10 +59,10 @@ func main() {
 	},
 	)
 
-	err := hub.ConnectToServer("/tinymq/server", true, nil)
-	if err != nil {
-		log.Fatalln("[client ConnectToServer ERROR]", err)
-	}
+	// err := hub.ConnectToServer("/tinymq/server", true, nil)
+	// if err != nil {
+	// 	log.Fatalln("[client ConnectToServer ERROR]", err)
+	// }
 
 	// 获取信息
 	rsp := hub.GetOne(remoteFilter, "hello", []byte("hello from client, hello from client, hello from client"))

+ 0 - 1
examples/client-ws2.go

@@ -44,7 +44,6 @@ func main() {
 		func(conn *tinymq.Line) {
 			log.Println("connect state", conn.Channel(), conn.State(), time.Since(conn.Updated()))
 		},
-		nil,
 	)
 
 	// 订阅频道

+ 0 - 1
examples/server.go

@@ -49,7 +49,6 @@ func main() {
 					})
 			}
 		},
-		nil,
 	)
 
 	// tcp2协议

+ 34 - 13
hub.go

@@ -56,13 +56,24 @@ type Hub struct {
 	// 连接状态变化时调用的函数
 	connectStatusFunc ConnectStatusFunc
 
-	// 验证发送数据的条件是否满足 (可空)
+	// 验证发送数据的条件是否满足 (可空)
 	checkConnectOkFunc CheckConnectOkFunc
 
+	// 通过过滤函数获取一个频道信息 (可空)
+	filterToChannelFunc FilterToChannelFunc
+
 	// 上次清理异常连接时间戳
 	lastCleanDeadConnect int64
 }
 
+func (h *Hub) SetCheckConnectOkFunc(fn CheckConnectOkFunc) {
+	h.checkConnectOkFunc = fn
+}
+
+func (h *Hub) SetFilterToChannelFunc(fn FilterToChannelFunc) {
+	h.filterToChannelFunc = fn
+}
+
 // 转换数据
 func (h *Hub) convertData(data any) (reqData []byte, err error) {
 	switch data := data.(type) {
@@ -314,7 +325,17 @@ func (h *Hub) sendRequest(gd *GetData) (count int) {
 		if count > 0 {
 			break
 		}
-		time.Sleep(time.Millisecond * 500)
+		// 如果是客户端,并且有机会自动连接,则尝试自动连接
+		if i == 0 && h.connectHostFunc != nil && h.filterToChannelFunc != nil {
+			channel := h.filterToChannelFunc(gd.Filter)
+			err := h.ConnectToServer(channel, false, nil)
+			if err != nil {
+				log.Println(err)
+				return 0
+			}
+		} else {
+			time.Sleep(time.Millisecond * 500)
+		}
 	}
 	return
 }
@@ -860,22 +881,22 @@ func NewHub(
 	// 连接状态变化时调用的函数
 	connectStatusFunc ConnectStatusFunc,
 	// 验证发送数据的条件是否满足 (可为空)
-	checkConnectOkFunc CheckConnectOkFunc,
+	// checkConnectOkFunc CheckConnectOkFunc,
 ) (h *Hub) {
 	if cf == nil {
 		cf = config.NewConfig()
 	}
 	h = &Hub{
-		cf:                   cf,
-		globalID:             uint16(time.Now().UnixNano()) % config.ID_MAX,
-		channel:              channel,
-		middle:               make([]MiddleFunc, 0),
-		lines:                NewMapx(),
-		connectHostFunc:      connectHostFunc,
-		authFunc:             authFunc,
-		checkAuthFunc:        checkAuthFunc,
-		connectStatusFunc:    connectStatusFunc,
-		checkConnectOkFunc:   checkConnectOkFunc,
+		cf:                cf,
+		globalID:          uint16(time.Now().UnixNano()) % config.ID_MAX,
+		channel:           channel,
+		middle:            make([]MiddleFunc, 0),
+		lines:             NewMapx(),
+		connectHostFunc:   connectHostFunc,
+		authFunc:          authFunc,
+		checkAuthFunc:     checkAuthFunc,
+		connectStatusFunc: connectStatusFunc,
+		// checkConnectOkFunc:   checkConnectOkFunc,
 		lastCleanDeadConnect: time.Now().UnixMilli(),
 	}
 	go h.checkConnect()

+ 3 - 0
type.go

@@ -29,6 +29,9 @@ type ConnectStatusFunc func(conn *Line)
 // 频道过滤器函数,如果返回true表示成功匹配
 type FilterFunc func(conn *Line) (ok bool)
 
+// 通过过滤函数获取一个频道信息
+type FilterToChannelFunc func(filter FilterFunc) (channel string)
+
 // 订阅频道数据结构
 type SubscribeData struct {
 	Filter   FilterFunc        // 频道匹配过滤

+ 1 - 1
type_test.go

@@ -3,7 +3,7 @@ package tinymq
 import "testing"
 
 func TestParseUrl(t *testing.T) {
-	url := "ws2://xor:s^7mv7L!Mrn8Y!vn@127.0.0.1:141/wsv2?proxy=1"
+	url := "ws2://xor:s^7mv7L!Mrn8Y!vn@127.0.0.1:141/wsv2?proxy=1&priority=90"
 	hostInfo, err := ParseUrl(url)
 	if err != nil {
 		t.Error(err)