Эх сурвалжийг харах

add data type in function

Joyit 2 долоо хоног өмнө
parent
commit
034462ff9f
3 өөрчлөгдсөн 51 нэмэгдсэн , 41 устгасан
  1. 9 1
      examples/client-ws2.go
  2. 38 39
      hub.go
  3. 4 1
      type.go

+ 9 - 1
examples/client-ws2.go

@@ -62,7 +62,15 @@ func main() {
 	}
 	}
 
 
 	// 获取信息
 	// 获取信息
-	rsp := hub.GetOne(remoteFilter, "hello", []byte("hello from client,hello from client,hello from client"))
+	rsp := hub.GetOne(remoteFilter, "hello", "hello from client,hello from client,hello from client")
+	if rsp.State != config.STATE_OK {
+		log.Println("error state:", rsp.State)
+		return
+	}
+	log.Println("[RESULT]<-", string(rsp.Data))
+	rsp = hub.GetOne(remoteFilter, "hello", func() ([]byte, error) {
+		return []byte("hello from data function"), nil
+	})
 	if rsp.State != config.STATE_OK {
 	if rsp.State != config.STATE_OK {
 		log.Println("error state:", rsp.State)
 		log.Println("error state:", rsp.State)
 		return
 		return

+ 38 - 39
hub.go

@@ -56,6 +56,35 @@ type Hub struct {
 	lastCleanDeadConnect int64
 	lastCleanDeadConnect int64
 }
 }
 
 
+// 转换数据
+func (h *Hub) convData(data any) (reqData []byte, err error) {
+	switch data := data.(type) {
+	case []byte:
+		reqData = data
+	case string:
+		reqData = []byte(data)
+	case func() ([]byte, error):
+		reqData, err = data()
+		if err != nil {
+			log.Println(err.Error())
+			return nil, err
+		}
+	default:
+		if data != nil {
+			// if s, ok := data.(func() ([]uint8, error)); ok {
+			// 	return s()
+			// }
+			// 自动转换数据为json格式
+			reqData, err = json.Marshal(data)
+			if err != nil {
+				log.Println(err.Error())
+				return nil, err
+			}
+		}
+	}
+	return
+}
+
 // 清理异常连接
 // 清理异常连接
 func (h *Hub) cleanDeadConnect() {
 func (h *Hub) cleanDeadConnect() {
 	h.Lock()
 	h.Lock()
@@ -220,6 +249,11 @@ func (h *Hub) outResponse(response *ResponseData) {
 // 发送数据到网络接口
 // 发送数据到网络接口
 // 返回发送的数量
 // 返回发送的数量
 func (h *Hub) sendRequest(gd *GetData) (count int) {
 func (h *Hub) sendRequest(gd *GetData) (count int) {
+	outData, err := h.convData(gd.Data)
+	if err != nil {
+		log.Println(err)
+		return 0
+	}
 	h.connects.Range(func(key, _ any) bool {
 	h.connects.Range(func(key, _ any) bool {
 		conn := key.(*Line)
 		conn := key.(*Line)
 		// 检查连接是否OK
 		// 检查连接是否OK
@@ -264,13 +298,13 @@ func (h *Hub) sendRequest(gd *GetData) (count int) {
 			conn.sendRequest <- &RequestData{
 			conn.sendRequest <- &RequestData{
 				Id:       id,
 				Id:       id,
 				Cmd:      gd.Cmd,
 				Cmd:      gd.Cmd,
-				Data:     gd.Data,
+				Data:     outData,
 				timeout:  gd.Timeout,
 				timeout:  gd.Timeout,
 				backchan: gd.backchan,
 				backchan: gd.backchan,
 				conn:     conn,
 				conn:     conn,
 			}
 			}
 			if h.cf.PrintMsg {
 			if h.cf.PrintMsg {
-				log.Println("[SEND]->", id, conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200))
+				log.Println("[SEND]->", id, conn.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
 			}
 			}
 			count++
 			count++
 			if gd.Max > 0 && count >= gd.Max {
 			if gd.Max > 0 && count >= gd.Max {
@@ -372,31 +406,13 @@ func (h *Hub) GetWithMaxAndTimeout(filter FilterFunc, cmd string, data any, back
 	if filter == nil {
 	if filter == nil {
 		return 0
 		return 0
 	}
 	}
-	var reqData []byte
-	switch data := data.(type) {
-	case []byte:
-		reqData = data
-	case string:
-		reqData = []byte(data)
-	default:
-		if data != nil {
-			// 自动转换数据为json格式
-			var err error
-			reqData, err = json.Marshal(data)
-			if err != nil {
-				log.Println(err.Error())
-				return 0
-			}
-		}
-	}
-
 	if timeout <= 0 {
 	if timeout <= 0 {
 		timeout = h.cf.ReadWait
 		timeout = h.cf.ReadWait
 	}
 	}
 	gd := &GetData{
 	gd := &GetData{
 		Filter:   filter,
 		Filter:   filter,
 		Cmd:      cmd,
 		Cmd:      cmd,
-		Data:     reqData,
+		Data:     data,
 		Max:      max,
 		Max:      max,
 		Timeout:  timeout,
 		Timeout:  timeout,
 		backchan: make(chan *ResponseData, 32),
 		backchan: make(chan *ResponseData, 32),
@@ -492,27 +508,10 @@ func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
 	if filter == nil {
 	if filter == nil {
 		return
 		return
 	}
 	}
-	var reqData []byte
-	switch data := data.(type) {
-	case []byte:
-		reqData = data
-	case string:
-		reqData = []byte(data)
-	default:
-		if data != nil {
-			// 自动转换数据为json格式
-			var err error
-			reqData, err = json.Marshal(data)
-			if err != nil {
-				log.Println(err.Error())
-				return
-			}
-		}
-	}
 	gd := &GetData{
 	gd := &GetData{
 		Filter:   filter,
 		Filter:   filter,
 		Cmd:      cmd,
 		Cmd:      cmd,
-		Data:     reqData,
+		Data:     data,
 		Max:      max,
 		Max:      max,
 		Timeout:  h.cf.ReadWait,
 		Timeout:  h.cf.ReadWait,
 		backchan: nil,
 		backchan: nil,

+ 4 - 1
type.go

@@ -25,6 +25,9 @@ type ConnectStatusFunc func(conn *Line)
 // 频道过滤器函数,如果返回true表示成功匹配
 // 频道过滤器函数,如果返回true表示成功匹配
 type FilterFunc func(conn *Line) (ok bool)
 type FilterFunc func(conn *Line) (ok bool)
 
 
+// 数据获取函数,为了应对需要延迟获取数据的情况
+// type DataFunc func() ([]byte, error)
+
 // 订阅频道数据结构
 // 订阅频道数据结构
 type SubscribeData struct {
 type SubscribeData struct {
 	// Channel  *regexp.Regexp    //频道的正则表达式
 	// Channel  *regexp.Regexp    //频道的正则表达式
@@ -38,7 +41,7 @@ type GetData struct {
 	// Channel *regexp.Regexp
 	// Channel *regexp.Regexp
 	Filter  FilterFunc // 命令匹配过滤
 	Filter  FilterFunc // 命令匹配过滤
 	Cmd     string
 	Cmd     string
-	Data    []byte
+	Data    any
 	Max     int // 获取数据的频道最多有几个,如果为0表示没有限制
 	Max     int // 获取数据的频道最多有几个,如果为0表示没有限制
 	Timeout int // 超时时间(毫秒)
 	Timeout int // 超时时间(毫秒)