Joyit 1 місяць тому
батько
коміт
b16b17c105
5 змінених файлів з 40 додано та 5 видалено
  1. 1 0
      README.md
  2. 2 0
      config/config.go
  3. 6 0
      examples/server.go
  4. 22 4
      hub.go
  5. 9 1
      type.go

+ 1 - 0
README.md

@@ -20,6 +20,7 @@
 
 ## 问题与优化
 
+- 增加订阅中间件,处理验证登录等问题
 - 建立内存池来分配内存,减少内存碎片
 - 同地址多连接共存,使用不同的连接发送消息,减少延时,提高消息送达可靠性
 - 转发地址定时测试切换回到主服务节点

+ 2 - 0
config/config.go

@@ -5,6 +5,8 @@ const (
 	MIN_SYSTEM_ERROR_CODE = 110 // 系统信息最小值
 	NEXT_SUBSCRIBE        = 111
 	NEXT_SUBSCRIBE_MSG    = "NEXT SUBSCRIBE"
+	FORBIDDEN             = 120
+	FORBIDDEN_MSG         = "FORBIDDEN"
 	SYSTEM_ERROR          = 123
 	SYSTEM_ERROR_MSG      = "SYSTEM ERROR"
 	GET_TIMEOUT           = 125

+ 6 - 0
examples/server.go

@@ -84,6 +84,12 @@ func main() {
 	}
 	hub.BindForServer(bindInfo)
 
+	// 中间件
+	hub.UseMiddle(func(request *tinymq.RequestData) (response *tinymq.ResponseData) {
+		log.Println("[Middle]", request.Conn().Channel(), request.Cmd)
+		return nil
+	})
+
 	// 订阅频道
 	hub.Subscribe(regexp.MustCompile(remoteChannel), "hello",
 		func(request *tinymq.RequestData) (state uint8, result []byte) {

+ 22 - 4
hub.go

@@ -34,10 +34,11 @@ type Hub struct {
 	sync.Mutex
 	cf         *config.Config
 	globalID   uint16
-	channel    string   // 本地频道信息
-	connects   sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
-	subscribes sync.Map // [cmd]->[]*SubscribeData   //注册绑定频道的函数,用于响应请求
-	msgCache   sync.Map //  map[uint16]*GetMsg //请求的回应记录,key为id
+	channel    string       // 本地频道信息
+	middle     []MiddleFunc // 中间件
+	connects   sync.Map     // map[*Line]bool(true) //记录当前的连接,方便查找
+	subscribes sync.Map     // [cmd]->[]*SubscribeData   //注册绑定频道的函数,用于响应请求
+	msgCache   sync.Map     //  map[uint16]*GetMsg //请求的回应记录,key为id
 
 	// 客户端需要用的函数
 	connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
@@ -92,6 +93,11 @@ func (h *Hub) GetID() uint16 {
 	return h.globalID
 }
 
+// 添加中间件
+func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
+	h.middle = append(h.middle, middleFunc)
+}
+
 // 注册频道,其中频道为正则表达式字符串
 func (h *Hub) Subscribe(channel *regexp.Regexp, cmd string, backFunc SubscribeBack) (err error) {
 	if channel == nil {
@@ -265,6 +271,17 @@ func (h *Hub) requestFromNet(request *RequestData) {
 	cmd := request.Cmd
 	channel := request.conn.channel
 	log.Println("[REQU]<-", channel, "["+cmd+"]", subStr(string(request.Data), 200))
+	// 执行中间件
+	for _, mdFunc := range h.middle {
+		rsp := mdFunc(request)
+		if rsp != nil {
+			if request.Id != 0 {
+				rsp.Id = request.Id
+				request.conn.sendResponse <- rsp
+			}
+			return
+		}
+	}
 	sub, ok := h.subscribes.Load(cmd)
 	if ok {
 		subs := sub.([]*SubscribeData)
@@ -668,6 +685,7 @@ func NewHub(
 	h = &Hub{
 		cf:                   cf,
 		channel:              channel,
+		middle:               make([]MiddleFunc, 0),
 		connectHostFunc:      connectHostFunc,
 		authFunc:             authFunc,
 		checkAuthFunc:        checkAuthFunc,

+ 9 - 1
type.go

@@ -5,9 +5,17 @@ import (
 	"time"
 )
 
+// 中间件函数,如果返回为nil则继续循环到下个中间件,如果返回不为空则直接返回结果
+type MiddleFunc func(request *RequestData) (response *ResponseData)
+
+// 订阅频道响应函数
 type SubscribeBack func(request *RequestData) (state uint8, result []byte)
+
+// GET 获取数据的回调函数,如果返回 false 则提前结束
 type GetBack func(response *ResponseData) (ok bool)
-type ConnectStatusFunc func(conn *Line) // 线路状态改变时调用
+
+// 线路状态改变时调用
+type ConnectStatusFunc func(conn *Line)
 
 // 订阅频道数据结构
 type SubscribeData struct {