|
@@ -0,0 +1,609 @@
|
|
|
|
+package tinymq
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ "errors"
|
|
|
|
+ "fmt"
|
|
|
|
+ "log"
|
|
|
|
+ "math/rand"
|
|
|
|
+ "net"
|
|
|
|
+ "strconv"
|
|
|
|
+ "strings"
|
|
|
|
+ "sync"
|
|
|
|
+ "time"
|
|
|
|
+
|
|
|
|
+ "git.me9.top/git/tinymq/config"
|
|
|
|
+ "git.me9.top/git/tinymq/conn"
|
|
|
|
+ "git.me9.top/git/tinymq/conn/tpv2"
|
|
|
|
+ "git.me9.top/git/tinymq/conn/wsv2"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+// 类似一个插座的功能,管理多个连接
|
|
|
|
+// 一个hub即可以是客户端,同时也可以是服务端
|
|
|
|
+// 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
|
|
|
|
+
|
|
|
|
+// 截取部分字符串
|
|
|
|
+func subStr(str string, length int) string {
|
|
|
|
+ if len(str) <= length {
|
|
|
|
+ return str
|
|
|
|
+ }
|
|
|
|
+ return str[0:length] + "..."
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type Hub struct {
|
|
|
|
+ sync.Mutex
|
|
|
|
+ cf *config.Config
|
|
|
|
+ globalID uint16
|
|
|
|
+ channel string // 本地频道信息
|
|
|
|
+ connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
|
|
|
|
+ subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
|
|
|
|
+ msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
|
|
|
|
+
|
|
|
|
+ // 客户端需要用的函数
|
|
|
|
+ connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
|
|
|
|
+ authFunc AuthFunc // 获取认证信息,用于发送给对方
|
|
|
|
+
|
|
|
|
+ // 服务端需要用的函数
|
|
|
|
+ checkAuthFunc CheckAuthFunc // 核对认证是否合法
|
|
|
|
+
|
|
|
|
+ // 连接状态变化时调用的函数
|
|
|
|
+ connectStatusFunc ConnectStatusFunc
|
|
|
|
+
|
|
|
|
+ // 上次清理异常连接时间戳
|
|
|
|
+ lastCleanDeadConnect int64
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 清理异常连接
|
|
|
|
+func (h *Hub) cleanDeadConnect() {
|
|
|
|
+ h.Lock()
|
|
|
|
+ defer h.Unlock()
|
|
|
|
+ now := time.Now().UnixMilli()
|
|
|
|
+ if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) {
|
|
|
|
+ h.lastCleanDeadConnect = now
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) {
|
|
|
|
+ h.connects.Delete(key)
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取通讯消息ID号
|
|
|
|
+func (h *Hub) GetID() uint16 {
|
|
|
|
+ h.Lock()
|
|
|
|
+ defer h.Unlock()
|
|
|
|
+ h.globalID++
|
|
|
|
+ if h.globalID <= 0 || h.globalID >= config.ID_MAX {
|
|
|
|
+ h.globalID = 1
|
|
|
|
+ }
|
|
|
|
+ for {
|
|
|
|
+ // 检查是否在请求队列中存在对应的id
|
|
|
|
+ if _, ok := h.msgCache.Load(h.globalID); ok {
|
|
|
|
+ h.globalID++
|
|
|
|
+ if h.globalID <= 0 || h.globalID >= config.ID_MAX {
|
|
|
|
+ h.globalID = 1
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ break
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return h.globalID
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 注册频道,其中频道为正则表达式字符串
|
|
|
|
+func (h *Hub) Subscribe(reg *SubscribeData) (err error) {
|
|
|
|
+ if reg.Channel == nil {
|
|
|
|
+ return errors.New("channel can not be nil")
|
|
|
|
+ }
|
|
|
|
+ cmd := reg.Cmd
|
|
|
|
+ sub, ok := h.subscribes.Load(cmd)
|
|
|
|
+ if ok {
|
|
|
|
+ h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ regs := make([]*SubscribeData, 1)
|
|
|
|
+ regs[0] = reg
|
|
|
|
+ h.subscribes.Store(cmd, regs)
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取当前在线的数量
|
|
|
|
+func (h *Hub) ConnectNum() int {
|
|
|
|
+ var count int
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ if key.(*Line).state == Connected {
|
|
|
|
+ count++
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ return count
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取所有的在线连接频道
|
|
|
|
+func (h *Hub) AllChannel() []string {
|
|
|
|
+ cs := make([]string, 0)
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.state == Connected {
|
|
|
|
+ cs = append(cs, line.channel)
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ return cs
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取所有连接频道和连接时长
|
|
|
|
+// 为了避免定义数据结构麻烦,采用|隔开
|
|
|
|
+func (h *Hub) AllChannelTime() []string {
|
|
|
|
+ cs := make([]string, 0)
|
|
|
|
+ h.connects.Range(func(key, value any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.state == Connected {
|
|
|
|
+ ti := time.Since(value.(time.Time)).Milliseconds()
|
|
|
|
+ cs = append(cs, line.channel+"|"+strconv.FormatInt(ti, 10))
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ return cs
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取频道并通过函数过滤,如果返回 false 将终止
|
|
|
|
+func (h *Hub) ChannelToFunc(fn func(string) bool) {
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.state == Connected {
|
|
|
|
+ return fn(line.channel)
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 从 channel 获取连接
|
|
|
|
+func (h *Hub) ChannelToLine(channel string) (line *Line) {
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ l := key.(*Line)
|
|
|
|
+ if l.channel == channel {
|
|
|
|
+ line = l
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 返回请求结果
|
|
|
|
+func (h *Hub) outResponse(response *ResponseData) {
|
|
|
|
+ defer recover() //避免管道已经关闭而引起panic
|
|
|
|
+ id := response.Id
|
|
|
|
+ t, ok := h.msgCache.Load(id)
|
|
|
|
+ if ok {
|
|
|
|
+ // 删除数据缓存
|
|
|
|
+ h.msgCache.Delete(id)
|
|
|
|
+ gm := t.(*GetMsg)
|
|
|
|
+ // 停止定时器
|
|
|
|
+ if !gm.timer.Stop() {
|
|
|
|
+ select {
|
|
|
|
+ case <-gm.timer.C:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // 回应数据到上层
|
|
|
|
+ gm.out <- response
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 发送数据到网络接口
|
|
|
|
+// 返回发送的数量
|
|
|
|
+func (h *Hub) sendRequest(gd *GetData) (count int) {
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ conn := key.(*Line)
|
|
|
|
+ // 检查连接是否OK
|
|
|
|
+ if conn.state != Connected {
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ if gd.Channel.MatchString(conn.channel) {
|
|
|
|
+ var id uint16
|
|
|
|
+ if gd.backchan != nil {
|
|
|
|
+ id = h.GetID()
|
|
|
|
+ timeout := gd.Timeout
|
|
|
|
+ if timeout <= 0 {
|
|
|
|
+ timeout = h.cf.WriteWait
|
|
|
|
+ }
|
|
|
|
+ fn := func(id uint16, conn *Line) func() {
|
|
|
|
+ return func() {
|
|
|
|
+ go h.outResponse(&ResponseData{
|
|
|
|
+ Id: id,
|
|
|
|
+ State: config.GET_TIMEOUT,
|
|
|
|
+ Data: []byte(config.GET_TIMEOUT_MSG),
|
|
|
|
+ conn: conn,
|
|
|
|
+ })
|
|
|
|
+ // 检查是否已经很久时间没有使用连接了
|
|
|
|
+ if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3)*time.Millisecond {
|
|
|
|
+ // 超时关闭当前的连接
|
|
|
|
+ log.Println("get message timeout", conn.channel)
|
|
|
|
+ // 有可能连接出现问题,断开并重新连接
|
|
|
|
+ conn.Close(false)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }(id, conn)
|
|
|
|
+ // 将要发送的请求缓存
|
|
|
|
+ gm := &GetMsg{
|
|
|
|
+ out: gd.backchan,
|
|
|
|
+ timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
|
|
|
|
+ }
|
|
|
|
+ h.msgCache.Store(id, gm)
|
|
|
|
+ }
|
|
|
|
+ // 组织数据并发送到Connect
|
|
|
|
+ conn.sendRequest <- &RequestData{
|
|
|
|
+ Id: id,
|
|
|
|
+ Cmd: gd.Cmd,
|
|
|
|
+ Data: gd.Data,
|
|
|
|
+ timeout: gd.Timeout,
|
|
|
|
+ backchan: gd.backchan,
|
|
|
|
+ conn: conn,
|
|
|
|
+ }
|
|
|
|
+ log.Println("[SEND]->", conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200))
|
|
|
|
+ count++
|
|
|
|
+ if gd.Max > 0 && count >= gd.Max {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 执行网络发送过来的命令
|
|
|
|
+func (h *Hub) requestFromNet(request *RequestData) {
|
|
|
|
+ cmd := request.Cmd
|
|
|
|
+ channel := request.conn.channel
|
|
|
|
+ log.Println("[REQU]<-", channel, "["+cmd+"]", subStr(string(request.Data), 200))
|
|
|
|
+ sub, ok := h.subscribes.Load(cmd)
|
|
|
|
+ if ok {
|
|
|
|
+ subs := sub.([]*SubscribeData)
|
|
|
|
+ // 倒序查找是为了新增的频道响应函数优先执行
|
|
|
|
+ for i := len(subs) - 1; i >= 0; i-- {
|
|
|
|
+ rg := subs[i]
|
|
|
|
+ if rg.Channel.MatchString(channel) {
|
|
|
|
+ state, data := rg.BackFunc(request)
|
|
|
|
+ // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
|
|
|
|
+ if state == config.NEXT_SUBSCRIBE {
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+ // 如果id为0表示不需要回应
|
|
|
|
+ if request.Id != 0 {
|
|
|
|
+ request.conn.sendResponse <- &ResponseData{
|
|
|
|
+ Id: request.Id,
|
|
|
|
+ State: state,
|
|
|
|
+ Data: data,
|
|
|
|
+ }
|
|
|
|
+ log.Println("[RESP]->", channel, "["+cmd+"]", state, subStr(string(data), 200))
|
|
|
|
+ }
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ log.Println("[not match command]", channel, cmd)
|
|
|
|
+ // 返回没有匹配的消息
|
|
|
|
+ request.conn.sendResponse <- &ResponseData{
|
|
|
|
+ Id: request.Id,
|
|
|
|
+ State: config.NO_MATCH,
|
|
|
|
+ Data: []byte(config.NO_MATCH_MSG),
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 请求频道并获取数据,采用回调的方式返回结果
|
|
|
|
+// 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
|
|
|
|
+// 如果 backFunc 返回为 false 则提前结束
|
|
|
|
+func (h *Hub) Get(gd *GetData, backFunc GetBack) (count int) {
|
|
|
|
+ // 排除空频道
|
|
|
|
+ if gd.Channel == nil {
|
|
|
|
+ return 0
|
|
|
|
+ }
|
|
|
|
+ if gd.Timeout <= 0 {
|
|
|
|
+ gd.Timeout = h.cf.ReadWait
|
|
|
|
+ }
|
|
|
|
+ if gd.backchan == nil {
|
|
|
|
+ gd.backchan = make(chan *ResponseData, 32)
|
|
|
|
+ }
|
|
|
|
+ max := h.sendRequest(gd)
|
|
|
|
+ if max <= 0 {
|
|
|
|
+ return 0
|
|
|
|
+ }
|
|
|
|
+ // 避免出现异常时线程无法退出
|
|
|
|
+ timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
|
|
|
|
+ defer func() {
|
|
|
|
+ if !timer.Stop() {
|
|
|
|
+ select {
|
|
|
|
+ case <-timer.C:
|
|
|
|
+ default:
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ close(gd.backchan)
|
|
|
|
+ }()
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case rp := <-gd.backchan:
|
|
|
|
+ if rp == nil || rp.conn == nil {
|
|
|
|
+ // 可能是已经退出了
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ ch := rp.conn.channel
|
|
|
|
+ log.Println("[RECV]<-", ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
|
|
|
|
+ count++
|
|
|
|
+ // 如果这里返回为false这跳出循环
|
|
|
|
+ if backFunc != nil && !backFunc(rp) {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ if count >= max {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ case <-timer.C:
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 只获取一个频道的数据,阻塞等待
|
|
|
|
+// 如果没有结果将返回 NO_MATCH
|
|
|
|
+func (h *Hub) GetOne(cmd *GetData) (response *ResponseData) {
|
|
|
|
+ cmd.Max = 1
|
|
|
|
+ h.Get(cmd, func(rp *ResponseData) (ok bool) {
|
|
|
|
+ response = rp
|
|
|
|
+ return false
|
|
|
|
+ })
|
|
|
|
+ if response == nil {
|
|
|
|
+ response = &ResponseData{
|
|
|
|
+ State: config.NO_MATCH,
|
|
|
|
+ Data: []byte(config.NO_MATCH_MSG),
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 推送消息出去,不需要返回数据
|
|
|
|
+func (h *Hub) Push(cmd *GetData) {
|
|
|
|
+ cmd.backchan = nil
|
|
|
|
+ h.sendRequest(cmd)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 增加连接
|
|
|
|
+func (h *Hub) addLine(line *Line) {
|
|
|
|
+ if _, ok := h.connects.Load(line); ok {
|
|
|
|
+ log.Println("connect have exist")
|
|
|
|
+ // 连接已经存在,直接返回
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ // 检查是否有相同的channel,如果有的话将其关闭删除
|
|
|
|
+ channel := line.channel
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ conn := key.(*Line)
|
|
|
|
+ // 删除超时的连接
|
|
|
|
+ if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5)*time.Millisecond {
|
|
|
|
+ h.connects.Delete(key)
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ if conn.channel == channel {
|
|
|
|
+ conn.Close(true)
|
|
|
|
+ h.connects.Delete(key)
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ h.connects.Store(line, true)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 删除连接
|
|
|
|
+func (h *Hub) removeLine(conn *Line) {
|
|
|
|
+ conn.Close(true)
|
|
|
|
+ h.connects.Delete(conn)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 获取指定连接的连接持续时间
|
|
|
|
+func (h *Hub) ConnectDuration(conn *Line) time.Duration {
|
|
|
|
+ t, ok := h.connects.Load(conn)
|
|
|
|
+ if ok {
|
|
|
|
+ return time.Since(t.(time.Time))
|
|
|
|
+ }
|
|
|
|
+ // 如果不存在直接返回0
|
|
|
|
+ return time.Duration(0)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 绑定端口,建立服务
|
|
|
|
+// 需要程序运行时调用
|
|
|
|
+func (h *Hub) BindForServer(info *HostInfo) (err error) {
|
|
|
|
+ doConnectFunc := func(conn conn.Connect) {
|
|
|
|
+ proto, version, channel, auth, err := conn.ReadAuthInfo()
|
|
|
|
+ if err != nil {
|
|
|
|
+ log.Println("[BindForServer ReadAuthInfo ERROR]", err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ if version != info.Version || proto != info.Proto {
|
|
|
|
+ log.Println("wrong version or protocol: ", version, proto)
|
|
|
|
+ conn.Close()
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ // 检查验证是否合法
|
|
|
|
+ if !h.checkAuthFunc(proto, version, channel, auth) {
|
|
|
|
+ conn.Close()
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ // 发送频道信息
|
|
|
|
+ if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
|
|
|
|
+ log.Println("[WriteAuthInfo ERROR]", err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ // 将连接加入现有连接中
|
|
|
|
+ done := false
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.state == Disconnected && line.channel == channel && line.host == nil {
|
|
|
|
+ line.Start(conn, nil)
|
|
|
|
+ done = true
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ // 新建一个连接
|
|
|
|
+ if !done {
|
|
|
|
+ line := NewConnect(h.cf, h, channel, conn, nil)
|
|
|
|
+ h.addLine(line)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if info.Version == wsv2.VERSION && info.Proto == wsv2.PROTO {
|
|
|
|
+ bind := ""
|
|
|
|
+ if info.Bind != "" {
|
|
|
|
+ bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
|
|
|
|
+ }
|
|
|
|
+ return wsv2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
|
|
|
|
+ } else if info.Version == tpv2.VERSION && info.Proto == tpv2.PROTO {
|
|
|
|
+ return tpv2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
|
|
|
|
+ }
|
|
|
|
+ return errors.New("not connect protocol and version found")
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 新建一个连接,不同的连接协议由底层自己选择
|
|
|
|
+// channel: 要连接的频道信息,需要能表达频道关键信息的部分
|
|
|
|
+func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
|
|
|
|
+ // 检查当前channel是否已经存在
|
|
|
|
+ if !force {
|
|
|
|
+ line := h.ChannelToLine(channel)
|
|
|
|
+ if line != nil && line.state == Connected {
|
|
|
|
+ err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // 获取服务地址等信息
|
|
|
|
+ host, err := h.connectHostFunc(channel)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ var conn conn.Connect
|
|
|
|
+ addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
|
|
|
|
+ if host.Version == wsv2.VERSION && host.Proto == wsv2.PROTO {
|
|
|
|
+ conn, err = wsv2.Client(h.cf, addr, host.Path, host.Hash)
|
|
|
|
+ } else if host.Version == tpv2.VERSION && host.Proto == tpv2.PROTO {
|
|
|
|
+ conn, err = tpv2.Client(h.cf, addr, host.Hash)
|
|
|
|
+ } else {
|
|
|
|
+ return fmt.Errorf("not correct protocol and version found in: %+v", host)
|
|
|
|
+ }
|
|
|
|
+ if err != nil {
|
|
|
|
+ log.Println("[Client ERROR]", host.Proto, err)
|
|
|
|
+ host.Errors++
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 发送验证信息
|
|
|
|
+ if err := conn.WriteAuthInfo(h.channel, h.authFunc(host.Proto, host.Version, channel, nil)); err != nil {
|
|
|
|
+ log.Println("[WriteAuthInfo ERROR]", err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ host.Errors++
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 接收频道信息
|
|
|
|
+ proto, version, channel2, _, err := conn.ReadAuthInfo()
|
|
|
|
+ if err != nil {
|
|
|
|
+ log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ host.Errors++
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 检查版本和协议是否一致
|
|
|
|
+ if version != host.Version || proto != host.Proto {
|
|
|
|
+ err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
|
|
|
|
+ log.Println(err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ host.Errors++
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 检查频道名称是否匹配
|
|
|
|
+ if !strings.Contains(channel2, channel) {
|
|
|
|
+ err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
|
|
|
|
+ log.Println(err)
|
|
|
|
+ conn.Close()
|
|
|
|
+ host.Errors++
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 更新服务主机信息
|
|
|
|
+ host.Errors = 0
|
|
|
|
+ host.Updated = time.Now()
|
|
|
|
+
|
|
|
|
+ // 将连接加入现有连接中
|
|
|
|
+ done := false
|
|
|
|
+ h.connects.Range(func(key, _ any) bool {
|
|
|
|
+ line := key.(*Line)
|
|
|
|
+ if line.channel == channel {
|
|
|
|
+ if line.state == Connected {
|
|
|
|
+ if !force {
|
|
|
|
+ err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
|
|
|
|
+ log.Println(err)
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ line.Start(conn, host)
|
|
|
|
+ done = true
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+ // 新建一个连接
|
|
|
|
+ if !done {
|
|
|
|
+ line := NewConnect(h.cf, h, channel, conn, host)
|
|
|
|
+ h.addLine(line)
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 重试方式连接服务
|
|
|
|
+// 将会一直阻塞直到连接成功
|
|
|
|
+func (h *Hub) ConnectToServerX(channel string, force bool) {
|
|
|
|
+ for {
|
|
|
|
+ err := h.ConnectToServer(channel, force)
|
|
|
|
+ if err == nil {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ log.Println("[ConnectToServer ERROR, try it again]", err)
|
|
|
|
+ // 产生一个随机数避免刹间重连过载
|
|
|
|
+ r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
|
|
+ time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 建立一个集线器
|
|
|
|
+// connectFunc 用于监听连接状态的函数,可以为nil
|
|
|
|
+func NewHub(
|
|
|
|
+ cf *config.Config,
|
|
|
|
+ channel string,
|
|
|
|
+ // 客户端需要用的函数
|
|
|
|
+ connectHostFunc ConnectHostFunc,
|
|
|
|
+ authFunc AuthFunc,
|
|
|
|
+ // 服务端需要用的函数
|
|
|
|
+ checkAuthFunc CheckAuthFunc,
|
|
|
|
+ // 连接状态变化时调用的函数
|
|
|
|
+ connectStatusFunc ConnectStatusFunc,
|
|
|
|
+) (h *Hub) {
|
|
|
|
+ h = &Hub{
|
|
|
|
+ cf: cf,
|
|
|
|
+ channel: channel,
|
|
|
|
+ connectHostFunc: connectHostFunc,
|
|
|
|
+ authFunc: authFunc,
|
|
|
|
+ checkAuthFunc: checkAuthFunc,
|
|
|
|
+ connectStatusFunc: connectStatusFunc,
|
|
|
|
+ lastCleanDeadConnect: time.Now().UnixMilli(),
|
|
|
|
+ }
|
|
|
|
+ return h
|
|
|
|
+}
|