package tinymq import ( "context" "errors" "fmt" "log" "math/rand" "net" "regexp" "strconv" "strings" "sync" "time" "git.me9.top/git/tinymq/config" "git.me9.top/git/tinymq/conn" "git.me9.top/git/tinymq/conn/tcp2" "git.me9.top/git/tinymq/conn/ws2" ) // 类似一个插座的功能,管理多个连接 // 一个hub即可以是客户端,同时也可以是服务端 // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。 // 截取部分字符串 func subStr(str string, length int) string { if len(str) <= length { return str } return str[0:length] + "..." } type Hub struct { sync.Mutex cf *config.Config globalID uint16 channel string // 本地频道信息 middle []MiddleFunc // 中间件 connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找 subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求 msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id // 客户端需要用的函数 connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址 authFunc AuthFunc // 获取认证信息,用于发送给对方 // 服务端需要用的函数 checkAuthFunc CheckAuthFunc // 核对认证是否合法 // 连接状态变化时调用的函数 connectStatusFunc ConnectStatusFunc // 上次清理异常连接时间戳 lastCleanDeadConnect int64 } // 清理异常连接 func (h *Hub) cleanDeadConnect() { h.Lock() defer h.Unlock() now := time.Now().UnixMilli() if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) { h.lastCleanDeadConnect = now h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) { h.connects.Delete(key) } return true }) } } // 获取通讯消息ID号 func (h *Hub) GetID() uint16 { h.Lock() defer h.Unlock() h.globalID++ if h.globalID <= 0 || h.globalID >= config.ID_MAX { h.globalID = 1 } for { // 检查是否在请求队列中存在对应的id if _, ok := h.msgCache.Load(h.globalID); ok { h.globalID++ if h.globalID <= 0 || h.globalID >= config.ID_MAX { h.globalID = 1 } } else { break } } return h.globalID } // 添加中间件 // 如果中间件函数返回为空,表示处理完成,通过 // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过 func (h *Hub) UseMiddle(middleFunc MiddleFunc) { h.middle = append(h.middle, middleFunc) } // 注册频道,其中频道为正则表达式字符串 func (h *Hub) Subscribe(channel *regexp.Regexp, cmd string, backFunc SubscribeBack) (err error) { if channel == nil { return errors.New("channel can not be nil") } reg := &SubscribeData{ Channel: channel, Cmd: cmd, BackFunc: backFunc, } sub, ok := h.subscribes.Load(cmd) if ok { h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg)) return } regs := make([]*SubscribeData, 1) regs[0] = reg h.subscribes.Store(cmd, regs) return } // 获取当前在线的数量 func (h *Hub) ConnectNum() int { var count int h.connects.Range(func(key, _ any) bool { if key.(*Line).state == Connected { count++ } return true }) return count } // 获取所有的在线连接频道 func (h *Hub) AllChannel() []string { cs := make([]string, 0) h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.state == Connected { cs = append(cs, line.channel) } return true }) return cs } // 获取所有连接频道和连接时长 // 为了避免定义数据结构麻烦,采用|隔开 func (h *Hub) AllChannelTime() []string { cs := make([]string, 0) h.connects.Range(func(key, value any) bool { line := key.(*Line) if line.state == Connected { ti := time.Since(value.(time.Time)).Milliseconds() cs = append(cs, line.channel+"|"+strconv.FormatInt(ti, 10)) } return true }) return cs } // 获取频道并通过函数过滤,如果返回 false 将终止 func (h *Hub) ChannelToFunc(fn func(string) bool) { h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.state == Connected { return fn(line.channel) } return true }) } // 从 channel 获取连接 func (h *Hub) ChannelToLine(channel string) (line *Line) { h.connects.Range(func(key, _ any) bool { l := key.(*Line) if l.channel == channel { line = l return false } return true }) return } // 返回请求结果 func (h *Hub) outResponse(response *ResponseData) { defer recover() //避免管道已经关闭而引起panic id := response.Id t, ok := h.msgCache.Load(id) if ok { // 删除数据缓存 h.msgCache.Delete(id) gm := t.(*GetMsg) // 停止定时器 if !gm.timer.Stop() { select { case <-gm.timer.C: default: } } // 回应数据到上层 gm.out <- response } } // 发送数据到网络接口 // 返回发送的数量 func (h *Hub) sendRequest(gd *GetData) (count int) { h.connects.Range(func(key, _ any) bool { conn := key.(*Line) // 检查连接是否OK if conn.state != Connected { return true } if gd.Channel.MatchString(conn.channel) { var id uint16 if gd.backchan != nil { id = h.GetID() timeout := gd.Timeout if timeout <= 0 { timeout = h.cf.WriteWait } fn := func(id uint16, conn *Line) func() { return func() { go h.outResponse(&ResponseData{ Id: id, State: config.GET_TIMEOUT, Data: []byte(config.GET_TIMEOUT_MSG), conn: conn, }) // 检查是否已经很久时间没有使用连接了 if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) { // 超时关闭当前的连接 log.Println("get message timeout", conn.channel) // 有可能连接出现问题,断开并重新连接 conn.Close(false) return } } }(id, conn) // 将要发送的请求缓存 gm := &GetMsg{ out: gd.backchan, timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn), } h.msgCache.Store(id, gm) } // 组织数据并发送到Connect conn.sendRequest <- &RequestData{ Id: id, Cmd: gd.Cmd, Data: gd.Data, timeout: gd.Timeout, backchan: gd.backchan, conn: conn, } if h.cf.PrintMsg { log.Println("[SEND]->", conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200)) } count++ if gd.Max > 0 && count >= gd.Max { return false } } return true }) return } // 执行网络发送过来的命令 func (h *Hub) requestFromNet(request *RequestData) { cmd := request.Cmd channel := request.conn.channel if h.cf.PrintMsg { log.Println("[REQU]<-", channel, "["+cmd+"]", subStr(string(request.Data), 200)) } // 执行中间件 for _, mdFunc := range h.middle { rsp := mdFunc(request) if rsp != nil { // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理 if rsp.State == config.NEXT_MIDDLE { continue } // 返回消息 if request.Id != 0 { rsp.Id = request.Id request.conn.sendResponse <- rsp } return } else { break } } sub, ok := h.subscribes.Load(cmd) if ok { subs := sub.([]*SubscribeData) // 倒序查找是为了新增的频道响应函数优先执行 for i := len(subs) - 1; i >= 0; i-- { rg := subs[i] if rg.Channel.MatchString(channel) { state, data := rg.BackFunc(request) // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理 if state == config.NEXT_SUBSCRIBE { continue } // 如果id为0表示不需要回应 if request.Id != 0 { request.conn.sendResponse <- &ResponseData{ Id: request.Id, State: state, Data: data, } if h.cf.PrintMsg { log.Println("[RESP]->", channel, "["+cmd+"]", state, subStr(string(data), 200)) } } return } } } log.Println("[not match command]", channel, cmd) // 返回没有匹配的消息 request.conn.sendResponse <- &ResponseData{ Id: request.Id, State: config.NO_MATCH, Data: []byte(config.NO_MATCH_MSG), } } // 请求频道并获取数据,采用回调的方式返回结果 // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量 // 如果 backFunc 返回为 false 则提前结束 // 最大数量和超时时间如果为0的话表示使用默认值 func (h *Hub) GetX(channel *regexp.Regexp, cmd string, data []byte, backFunc GetBack, max int, timeout int) (count int) { // 排除空频道 if channel == nil { return 0 } if timeout <= 0 { timeout = h.cf.ReadWait } gd := &GetData{ Channel: channel, Cmd: cmd, Data: data, Max: max, Timeout: timeout, backchan: make(chan *ResponseData, 32), } sendMax := h.sendRequest(gd) if sendMax <= 0 { return 0 } // 避免出现异常时线程无法退出 timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2)) defer func() { if !timer.Stop() { select { case <-timer.C: default: } } close(gd.backchan) }() for { select { case rp := <-gd.backchan: if rp == nil || rp.conn == nil { // 可能是已经退出了 return } ch := rp.conn.channel if h.cf.PrintMsg { log.Println("[RECV]<-", ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200)) } count++ // 如果这里返回为false这跳出循环 if backFunc != nil && !backFunc(rp) { return } if count >= sendMax { return } case <-timer.C: return } } // return } // 请求频道并获取数据,采用回调的方式返回结果 // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量 // 如果 backFunc 返回为 false 则提前结束 func (h *Hub) Get(channel *regexp.Regexp, cmd string, data []byte, backFunc GetBack) (count int) { return h.GetX(channel, cmd, data, backFunc, 0, 0) } // 只获取一个频道的数据,阻塞等待到默认超时间隔 // 如果没有结果将返回 NO_MATCH func (h *Hub) GetOne(channel *regexp.Regexp, cmd string, data []byte) (response *ResponseData) { h.GetX(channel, cmd, data, func(rp *ResponseData) (ok bool) { response = rp return false }, 1, 0) if response == nil { response = &ResponseData{ State: config.NO_MATCH, Data: []byte(config.NO_MATCH_MSG), } } return } // 只获取一个频道的数据,阻塞等待到指定超时间隔 // 如果没有结果将返回 NO_MATCH func (h *Hub) GetOneX(channel *regexp.Regexp, cmd string, data []byte, timeout int) (response *ResponseData) { h.GetX(channel, cmd, data, func(rp *ResponseData) (ok bool) { response = rp return false }, 1, timeout) if response == nil { response = &ResponseData{ State: config.NO_MATCH, Data: []byte(config.NO_MATCH_MSG), } } return } // 推送消息出去,不需要返回数据 func (h *Hub) Push(channel *regexp.Regexp, cmd string, data []byte) { // 排除空频道 if channel == nil { return } gd := &GetData{ Channel: channel, Cmd: cmd, Data: data, Timeout: h.cf.ReadWait, backchan: nil, } h.sendRequest(gd) } // 推送最大对应数量的消息出去,不需要返回数据 func (h *Hub) PushX(channel *regexp.Regexp, cmd string, data []byte, max int) { // 排除空频道 if channel == nil { return } gd := &GetData{ Channel: channel, Cmd: cmd, Data: data, Max: max, Timeout: h.cf.ReadWait, backchan: nil, } h.sendRequest(gd) } // 增加连接 func (h *Hub) addLine(line *Line) { if _, ok := h.connects.Load(line); ok { log.Println("connect have exist") // 连接已经存在,直接返回 return } // 检查是否有相同的channel,如果有的话将其关闭删除 channel := line.channel h.connects.Range(func(key, _ any) bool { conn := key.(*Line) // 删除超时的连接 if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5*int(time.Millisecond)) { h.connects.Delete(key) return true } if conn.channel == channel { conn.Close(true) h.connects.Delete(key) return false } return true }) h.connects.Store(line, true) } // 删除连接 func (h *Hub) removeLine(conn *Line) { conn.Close(true) h.connects.Delete(conn) } // 获取指定连接的连接持续时间 func (h *Hub) ConnectDuration(conn *Line) time.Duration { t, ok := h.connects.Load(conn) if ok { return time.Since(t.(time.Time)) } // 如果不存在直接返回0 return time.Duration(0) } // 绑定端口,建立服务 // 需要程序运行时调用 func (h *Hub) BindForServer(info *HostInfo) (err error) { doConnectFunc := func(conn conn.Connect) { proto, version, channel, auth, err := conn.ReadAuthInfo() if err != nil { log.Println("[BindForServer ReadAuthInfo ERROR]", err) conn.Close() return } if version != info.Version || proto != info.Proto { log.Println("wrong version or protocol: ", version, proto) conn.Close() return } // 检查验证是否合法 if !h.checkAuthFunc(proto, version, channel, auth) { conn.Close() return } // 发送频道信息 if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil { log.Println("[WriteAuthInfo ERROR]", err) conn.Close() return } // 将连接加入现有连接中 done := false h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.state == Disconnected && line.channel == channel && line.host == nil { line.Start(conn, nil) done = true return false } return true }) // 新建一个连接 if !done { line := NewConnect(h.cf, h, channel, conn, nil) h.addLine(line) } } if info.Version == ws2.VERSION && info.Proto == ws2.PROTO { bind := "" if info.Bind != "" { bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))) } return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc) } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO { return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc) } return errors.New("not connect protocol and version found") } // 新建一个连接,不同的连接协议由底层自己选择 // channel: 要连接的频道信息,需要能表达频道关键信息的部分 func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) { // 检查当前channel是否已经存在 if !force { line := h.ChannelToLine(channel) if line != nil && line.state == Connected { err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel) return } } if host == nil { // 获取服务地址等信息 host, err = h.connectHostFunc(channel, true) if err != nil { return err } } var conn conn.Connect var runProto string addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port))) // 添加定时器 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout)) defer cancel() taskCh := make(chan bool) done := false go func() { if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) { runProto = ws2.PROTO conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash) } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO { runProto = tcp2.PROTO conn, err = tcp2.Dial(h.cf, addr, host.Hash) } else { err = fmt.Errorf("not correct protocol and version found in: %+v", host) } if done { if err != nil { log.Println("[Dial ERROR]", err) } if conn != nil { conn.Close() } } else { taskCh <- err == nil } }() select { case ok := <-taskCh: cancel() if !ok || err != nil || conn == nil { log.Println("[Client ERROR]", host.Proto, err) host.Errors++ host.Updated = time.Now() if err == nil { err = errors.New("unknown error") } return err } case <-ctx.Done(): done = true return errors.New("timeout") } // 发送验证信息 if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil { log.Println("[WriteAuthInfo ERROR]", err) conn.Close() host.Errors++ host.Updated = time.Now() return err } // 接收频道信息 proto, version, channel2, auth, err := conn.ReadAuthInfo() if err != nil { log.Println("[ConnectToServer ReadAuthInfo ERROR]", err) conn.Close() host.Errors++ host.Updated = time.Now() return err } // 检查版本和协议是否一致 if version != host.Version || proto != runProto { err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto) log.Println(err) conn.Close() host.Errors++ host.Updated = time.Now() return err } // 检查频道名称是否匹配 if !strings.Contains(channel2, channel) { err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2) log.Println(err) conn.Close() host.Errors++ host.Updated = time.Now() return err } // 检查验证是否合法 if !h.checkAuthFunc(proto, version, channel, auth) { err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth)) log.Println(err) conn.Close() host.Errors++ host.Updated = time.Now() return err } // 更新服务主机信息 host.Errors = 0 host.Updated = time.Now() // 将连接加入现有连接中 done = false h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.channel == channel { if line.state == Connected { if force { line.Close(true) } else { err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel) log.Println(err) return false } } line.Start(conn, host) done = true return false } return true }) if err != nil { return err } // 新建一个连接 if !done { line := NewConnect(h.cf, h, channel, conn, host) h.addLine(line) } return nil } // 重试方式连接服务 // 将会一直阻塞直到连接成功 func (h *Hub) ConnectToServerX(channel string, force bool) { host, _ := h.connectHostFunc(channel, false) for { err := h.ConnectToServer(channel, force, host) if err == nil { return } host = nil log.Println("[ConnectToServer ERROR, try it again]", err) // 产生一个随机数避免刹间重连过载 r := rand.New(rand.NewSource(time.Now().UnixNano())) time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond) } } // 检测处理代理连接 func (h *Hub) checkProxyConnect() { if h.cf.ProxyTimeout <= 0 { return } proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond))) for { <-proxyTicker.C now := time.Now().UnixMilli() h.connects.Range(func(key, _ any) bool { line := key.(*Line) if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) { host, err := h.connectHostFunc(line.channel, false) if err != nil { log.Println("[checkProxyConnect connectHostFunc ERROR]", err) return false } err = h.ConnectToServer(line.channel, true, host) if err != nil { log.Println("[checkProxyConnect ConnectToServer WARNING]", err) } } return true }) } } // 建立一个集线器 // connectFunc 用于监听连接状态的函数,可以为nil func NewHub( cf *config.Config, channel string, // 客户端需要用的函数 connectHostFunc ConnectHostFunc, authFunc AuthFunc, // 服务端需要用的函数 checkAuthFunc CheckAuthFunc, // 连接状态变化时调用的函数 connectStatusFunc ConnectStatusFunc, ) (h *Hub) { h = &Hub{ cf: cf, channel: channel, middle: make([]MiddleFunc, 0), connectHostFunc: connectHostFunc, authFunc: authFunc, checkAuthFunc: checkAuthFunc, connectStatusFunc: connectStatusFunc, lastCleanDeadConnect: time.Now().UnixMilli(), } go h.checkProxyConnect() return h }