hub.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791
  1. package tinymq
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "math/rand"
  8. "net"
  9. "regexp"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "time"
  14. "git.me9.top/git/tinymq/config"
  15. "git.me9.top/git/tinymq/conn"
  16. "git.me9.top/git/tinymq/conn/tcp2"
  17. "git.me9.top/git/tinymq/conn/ws2"
  18. )
  19. // 类似一个插座的功能,管理多个连接
  20. // 一个hub即可以是客户端,同时也可以是服务端
  21. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  22. // 截取部分字符串
  23. func subStr(str string, length int) string {
  24. if len(str) <= length {
  25. return str
  26. }
  27. return str[0:length] + "..."
  28. }
  29. type Hub struct {
  30. sync.Mutex
  31. cf *config.Config
  32. globalID uint16
  33. channel string // 本地频道信息
  34. middle []MiddleFunc // 中间件
  35. connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  36. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  37. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  38. // 客户端需要用的函数
  39. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  40. authFunc AuthFunc // 获取认证信息,用于发送给对方
  41. // 服务端需要用的函数
  42. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  43. // 连接状态变化时调用的函数
  44. connectStatusFunc ConnectStatusFunc
  45. // 上次清理异常连接时间戳
  46. lastCleanDeadConnect int64
  47. }
  48. // 清理异常连接
  49. func (h *Hub) cleanDeadConnect() {
  50. h.Lock()
  51. defer h.Unlock()
  52. now := time.Now().UnixMilli()
  53. if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) {
  54. h.lastCleanDeadConnect = now
  55. h.connects.Range(func(key, _ any) bool {
  56. line := key.(*Line)
  57. if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) {
  58. h.connects.Delete(key)
  59. }
  60. return true
  61. })
  62. }
  63. }
  64. // 获取通讯消息ID号
  65. func (h *Hub) GetID() uint16 {
  66. h.Lock()
  67. defer h.Unlock()
  68. h.globalID++
  69. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  70. h.globalID = 1
  71. }
  72. for {
  73. // 检查是否在请求队列中存在对应的id
  74. if _, ok := h.msgCache.Load(h.globalID); ok {
  75. h.globalID++
  76. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  77. h.globalID = 1
  78. }
  79. } else {
  80. break
  81. }
  82. }
  83. return h.globalID
  84. }
  85. // 添加中间件
  86. // 如果中间件函数返回为空,表示处理完成,通过
  87. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  88. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  89. h.middle = append(h.middle, middleFunc)
  90. }
  91. // 注册频道,其中频道为正则表达式字符串
  92. func (h *Hub) Subscribe(channel *regexp.Regexp, cmd string, backFunc SubscribeBack) (err error) {
  93. if channel == nil {
  94. return errors.New("channel can not be nil")
  95. }
  96. reg := &SubscribeData{
  97. Channel: channel,
  98. Cmd: cmd,
  99. BackFunc: backFunc,
  100. }
  101. sub, ok := h.subscribes.Load(cmd)
  102. if ok {
  103. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  104. return
  105. }
  106. regs := make([]*SubscribeData, 1)
  107. regs[0] = reg
  108. h.subscribes.Store(cmd, regs)
  109. return
  110. }
  111. // 遍历频道列表
  112. // 如果 fn 返回 false,则 range 停止迭代
  113. func (h *Hub) ConnectRange(fn func(line *Line) bool) {
  114. h.connects.Range(func(key, _ any) bool {
  115. line := key.(*Line)
  116. return fn(line)
  117. })
  118. }
  119. // 获取当前在线的数量
  120. func (h *Hub) ConnectNum() int {
  121. var count int
  122. h.connects.Range(func(key, _ any) bool {
  123. if key.(*Line).state == Connected {
  124. count++
  125. }
  126. return true
  127. })
  128. return count
  129. }
  130. // 获取所有的在线连接频道
  131. func (h *Hub) AllChannel() []string {
  132. cs := make([]string, 0)
  133. h.connects.Range(func(key, _ any) bool {
  134. line := key.(*Line)
  135. if line.state == Connected {
  136. cs = append(cs, line.channel)
  137. }
  138. return true
  139. })
  140. return cs
  141. }
  142. // 获取所有连接频道和连接时长
  143. // 为了避免定义数据结构麻烦,采用|隔开
  144. func (h *Hub) AllChannelTime() []string {
  145. cs := make([]string, 0)
  146. h.connects.Range(func(key, value any) bool {
  147. line := key.(*Line)
  148. if line.state == Connected {
  149. ti := time.Since(value.(time.Time)).Milliseconds()
  150. cs = append(cs, line.channel+"|"+strconv.FormatInt(ti, 10))
  151. }
  152. return true
  153. })
  154. return cs
  155. }
  156. // 获取频道并通过函数过滤,如果返回 false 将终止
  157. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  158. h.connects.Range(func(key, _ any) bool {
  159. line := key.(*Line)
  160. if line.state == Connected {
  161. return fn(line.channel)
  162. }
  163. return true
  164. })
  165. }
  166. // 从 channel 获取连接
  167. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  168. h.connects.Range(func(key, _ any) bool {
  169. l := key.(*Line)
  170. if l.channel == channel {
  171. line = l
  172. return false
  173. }
  174. return true
  175. })
  176. return
  177. }
  178. // 返回请求结果
  179. func (h *Hub) outResponse(response *ResponseData) {
  180. defer recover() //避免管道已经关闭而引起panic
  181. id := response.Id
  182. t, ok := h.msgCache.Load(id)
  183. if ok {
  184. // 删除数据缓存
  185. h.msgCache.Delete(id)
  186. gm := t.(*GetMsg)
  187. // 停止定时器
  188. if !gm.timer.Stop() {
  189. select {
  190. case <-gm.timer.C:
  191. default:
  192. }
  193. }
  194. // 回应数据到上层
  195. gm.out <- response
  196. }
  197. }
  198. // 发送数据到网络接口
  199. // 返回发送的数量
  200. func (h *Hub) sendRequest(gd *GetData) (count int) {
  201. h.connects.Range(func(key, _ any) bool {
  202. conn := key.(*Line)
  203. // 检查连接是否OK
  204. if conn.state != Connected {
  205. return true
  206. }
  207. if gd.Channel.MatchString(conn.channel) {
  208. var id uint16
  209. if gd.backchan != nil {
  210. id = h.GetID()
  211. timeout := gd.Timeout
  212. if timeout <= 0 {
  213. timeout = h.cf.WriteWait
  214. }
  215. fn := func(id uint16, conn *Line) func() {
  216. return func() {
  217. go h.outResponse(&ResponseData{
  218. Id: id,
  219. State: config.GET_TIMEOUT,
  220. Data: []byte(config.GET_TIMEOUT_MSG),
  221. conn: conn,
  222. })
  223. // 检查是否已经很久时间没有使用连接了
  224. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  225. // 超时关闭当前的连接
  226. log.Println("get message timeout", conn.channel)
  227. // 有可能连接出现问题,断开并重新连接
  228. conn.Close(false)
  229. return
  230. }
  231. }
  232. }(id, conn)
  233. // 将要发送的请求缓存
  234. gm := &GetMsg{
  235. out: gd.backchan,
  236. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  237. }
  238. h.msgCache.Store(id, gm)
  239. }
  240. // 组织数据并发送到Connect
  241. conn.sendRequest <- &RequestData{
  242. Id: id,
  243. Cmd: gd.Cmd,
  244. Data: gd.Data,
  245. timeout: gd.Timeout,
  246. backchan: gd.backchan,
  247. conn: conn,
  248. }
  249. if h.cf.PrintMsg {
  250. log.Println("[SEND]->", conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200))
  251. }
  252. count++
  253. if gd.Max > 0 && count >= gd.Max {
  254. return false
  255. }
  256. }
  257. return true
  258. })
  259. return
  260. }
  261. // 执行网络发送过来的命令
  262. func (h *Hub) requestFromNet(request *RequestData) {
  263. cmd := request.Cmd
  264. channel := request.conn.channel
  265. if h.cf.PrintMsg {
  266. log.Println("[REQU]<-", channel, "["+cmd+"]", subStr(string(request.Data), 200))
  267. }
  268. // 执行中间件
  269. for _, mdFunc := range h.middle {
  270. rsp := mdFunc(request)
  271. if rsp != nil {
  272. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  273. if rsp.State == config.NEXT_MIDDLE {
  274. continue
  275. }
  276. // 返回消息
  277. if request.Id != 0 {
  278. rsp.Id = request.Id
  279. request.conn.sendResponse <- rsp
  280. }
  281. return
  282. } else {
  283. break
  284. }
  285. }
  286. sub, ok := h.subscribes.Load(cmd)
  287. if ok {
  288. subs := sub.([]*SubscribeData)
  289. // 倒序查找是为了新增的频道响应函数优先执行
  290. for i := len(subs) - 1; i >= 0; i-- {
  291. rg := subs[i]
  292. if rg.Channel.MatchString(channel) {
  293. state, data := rg.BackFunc(request)
  294. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  295. if state == config.NEXT_SUBSCRIBE {
  296. continue
  297. }
  298. // 如果id为0表示不需要回应
  299. if request.Id != 0 {
  300. request.conn.sendResponse <- &ResponseData{
  301. Id: request.Id,
  302. State: state,
  303. Data: data,
  304. }
  305. if h.cf.PrintMsg {
  306. log.Println("[RESP]->", channel, "["+cmd+"]", state, subStr(string(data), 200))
  307. }
  308. }
  309. return
  310. }
  311. }
  312. }
  313. log.Println("[not match command]", channel, cmd)
  314. // 返回没有匹配的消息
  315. request.conn.sendResponse <- &ResponseData{
  316. Id: request.Id,
  317. State: config.NO_MATCH,
  318. Data: []byte(config.NO_MATCH_MSG),
  319. }
  320. }
  321. // 请求频道并获取数据,采用回调的方式返回结果
  322. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  323. // 如果 backFunc 返回为 false 则提前结束
  324. // 最大数量和超时时间如果为0的话表示使用默认值
  325. func (h *Hub) GetWithMaxAndTimeout(channel *regexp.Regexp, cmd string, data []byte, backFunc GetBack, max int, timeout int) (count int) {
  326. // 排除空频道
  327. if channel == nil {
  328. return 0
  329. }
  330. if timeout <= 0 {
  331. timeout = h.cf.ReadWait
  332. }
  333. gd := &GetData{
  334. Channel: channel,
  335. Cmd: cmd,
  336. Data: data,
  337. Max: max,
  338. Timeout: timeout,
  339. backchan: make(chan *ResponseData, 32),
  340. }
  341. sendMax := h.sendRequest(gd)
  342. if sendMax <= 0 {
  343. return 0
  344. }
  345. // 避免出现异常时线程无法退出
  346. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  347. defer func() {
  348. if !timer.Stop() {
  349. select {
  350. case <-timer.C:
  351. default:
  352. }
  353. }
  354. close(gd.backchan)
  355. }()
  356. for {
  357. select {
  358. case rp := <-gd.backchan:
  359. if rp == nil || rp.conn == nil {
  360. // 可能是已经退出了
  361. return
  362. }
  363. ch := rp.conn.channel
  364. if h.cf.PrintMsg {
  365. log.Println("[RECV]<-", ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  366. }
  367. count++
  368. // 如果这里返回为false这跳出循环
  369. if backFunc != nil && !backFunc(rp) {
  370. return
  371. }
  372. if count >= sendMax {
  373. return
  374. }
  375. case <-timer.C:
  376. return
  377. }
  378. }
  379. // return
  380. }
  381. // 请求频道并获取数据,采用回调的方式返回结果
  382. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  383. // 如果 backFunc 返回为 false 则提前结束
  384. func (h *Hub) Get(channel *regexp.Regexp, cmd string, data []byte, backFunc GetBack) (count int) {
  385. return h.GetWithMaxAndTimeout(channel, cmd, data, backFunc, 0, 0)
  386. }
  387. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  388. // 如果没有结果将返回 NO_MATCH
  389. func (h *Hub) GetOne(channel *regexp.Regexp, cmd string, data []byte) (response *ResponseData) {
  390. h.GetWithMaxAndTimeout(channel, cmd, data, func(rp *ResponseData) (ok bool) {
  391. response = rp
  392. return false
  393. }, 1, 0)
  394. if response == nil {
  395. response = &ResponseData{
  396. State: config.NO_MATCH,
  397. Data: []byte(config.NO_MATCH_MSG),
  398. }
  399. }
  400. return
  401. }
  402. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  403. // 如果没有结果将返回 NO_MATCH
  404. func (h *Hub) GetOneWithTimeout(channel *regexp.Regexp, cmd string, data []byte, timeout int) (response *ResponseData) {
  405. h.GetWithMaxAndTimeout(channel, cmd, data, func(rp *ResponseData) (ok bool) {
  406. response = rp
  407. return false
  408. }, 1, timeout)
  409. if response == nil {
  410. response = &ResponseData{
  411. State: config.NO_MATCH,
  412. Data: []byte(config.NO_MATCH_MSG),
  413. }
  414. }
  415. return
  416. }
  417. // 推送消息出去,不需要返回数据
  418. func (h *Hub) Push(channel *regexp.Regexp, cmd string, data []byte) {
  419. // 排除空频道
  420. if channel == nil {
  421. return
  422. }
  423. gd := &GetData{
  424. Channel: channel,
  425. Cmd: cmd,
  426. Data: data,
  427. Timeout: h.cf.ReadWait,
  428. backchan: nil,
  429. }
  430. h.sendRequest(gd)
  431. }
  432. // 推送最大对应数量的消息出去,不需要返回数据
  433. func (h *Hub) PushWithMax(channel *regexp.Regexp, cmd string, data []byte, max int) {
  434. // 排除空频道
  435. if channel == nil {
  436. return
  437. }
  438. gd := &GetData{
  439. Channel: channel,
  440. Cmd: cmd,
  441. Data: data,
  442. Max: max,
  443. Timeout: h.cf.ReadWait,
  444. backchan: nil,
  445. }
  446. h.sendRequest(gd)
  447. }
  448. // 增加连接
  449. func (h *Hub) addLine(line *Line) {
  450. if _, ok := h.connects.Load(line); ok {
  451. log.Println("connect have exist")
  452. // 连接已经存在,直接返回
  453. return
  454. }
  455. // 检查是否有相同的channel,如果有的话将其关闭删除
  456. channel := line.channel
  457. h.connects.Range(func(key, _ any) bool {
  458. conn := key.(*Line)
  459. // 删除超时的连接
  460. if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5*int(time.Millisecond)) {
  461. h.connects.Delete(key)
  462. return true
  463. }
  464. if conn.channel == channel {
  465. conn.Close(true)
  466. h.connects.Delete(key)
  467. return false
  468. }
  469. return true
  470. })
  471. h.connects.Store(line, true)
  472. }
  473. // 删除连接
  474. func (h *Hub) removeLine(conn *Line) {
  475. conn.Close(true)
  476. h.connects.Delete(conn)
  477. }
  478. // 获取指定连接的连接持续时间
  479. func (h *Hub) ConnectDuration(conn *Line) time.Duration {
  480. t, ok := h.connects.Load(conn)
  481. if ok {
  482. return time.Since(t.(time.Time))
  483. }
  484. // 如果不存在直接返回0
  485. return time.Duration(0)
  486. }
  487. // 绑定端口,建立服务
  488. // 需要程序运行时调用
  489. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  490. doConnectFunc := func(conn conn.Connect) {
  491. proto, version, channel, auth, err := conn.ReadAuthInfo()
  492. if err != nil {
  493. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  494. conn.Close()
  495. return
  496. }
  497. if version != info.Version || proto != info.Proto {
  498. log.Println("wrong version or protocol: ", version, proto)
  499. conn.Close()
  500. return
  501. }
  502. // 检查验证是否合法
  503. if !h.checkAuthFunc(proto, version, channel, auth) {
  504. conn.Close()
  505. return
  506. }
  507. // 发送频道信息
  508. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  509. log.Println("[WriteAuthInfo ERROR]", err)
  510. conn.Close()
  511. return
  512. }
  513. // 将连接加入现有连接中
  514. done := false
  515. h.connects.Range(func(key, _ any) bool {
  516. line := key.(*Line)
  517. if line.state == Disconnected && line.channel == channel && line.host == nil {
  518. line.Start(conn, nil)
  519. done = true
  520. return false
  521. }
  522. return true
  523. })
  524. // 新建一个连接
  525. if !done {
  526. line := NewConnect(h.cf, h, channel, conn, nil)
  527. h.addLine(line)
  528. }
  529. }
  530. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  531. bind := ""
  532. if info.Bind != "" {
  533. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  534. }
  535. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  536. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  537. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  538. }
  539. return errors.New("not connect protocol and version found")
  540. }
  541. // 新建一个连接,不同的连接协议由底层自己选择
  542. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  543. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  544. // 检查当前channel是否已经存在
  545. if !force {
  546. line := h.ChannelToLine(channel)
  547. if line != nil && line.state == Connected {
  548. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  549. return
  550. }
  551. }
  552. if host == nil {
  553. // 获取服务地址等信息
  554. host, err = h.connectHostFunc(channel, Both)
  555. if err != nil {
  556. return err
  557. }
  558. }
  559. var conn conn.Connect
  560. var runProto string
  561. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  562. // 添加定时器
  563. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  564. defer cancel()
  565. taskCh := make(chan bool)
  566. done := false
  567. go func() {
  568. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  569. runProto = ws2.PROTO
  570. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  571. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  572. runProto = tcp2.PROTO
  573. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  574. } else {
  575. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  576. }
  577. if done {
  578. if err != nil {
  579. log.Println("[Dial ERROR]", err)
  580. }
  581. if conn != nil {
  582. conn.Close()
  583. }
  584. } else {
  585. taskCh <- err == nil
  586. }
  587. }()
  588. select {
  589. case ok := <-taskCh:
  590. cancel()
  591. if !ok || err != nil || conn == nil {
  592. log.Println("[Client ERROR]", host.Proto, err)
  593. host.Errors++
  594. host.Updated = time.Now()
  595. if err == nil {
  596. err = errors.New("unknown error")
  597. }
  598. return err
  599. }
  600. case <-ctx.Done():
  601. done = true
  602. return errors.New("timeout")
  603. }
  604. // 发送验证信息
  605. if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
  606. log.Println("[WriteAuthInfo ERROR]", err)
  607. conn.Close()
  608. host.Errors++
  609. host.Updated = time.Now()
  610. return err
  611. }
  612. // 接收频道信息
  613. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  614. if err != nil {
  615. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  616. conn.Close()
  617. host.Errors++
  618. host.Updated = time.Now()
  619. return err
  620. }
  621. // 检查版本和协议是否一致
  622. if version != host.Version || proto != runProto {
  623. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  624. log.Println(err)
  625. conn.Close()
  626. host.Errors++
  627. host.Updated = time.Now()
  628. return err
  629. }
  630. // 检查频道名称是否匹配
  631. if !strings.Contains(channel2, channel) {
  632. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  633. log.Println(err)
  634. conn.Close()
  635. host.Errors++
  636. host.Updated = time.Now()
  637. return err
  638. }
  639. // 检查验证是否合法
  640. if !h.checkAuthFunc(proto, version, channel, auth) {
  641. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  642. log.Println(err)
  643. conn.Close()
  644. host.Errors++
  645. host.Updated = time.Now()
  646. return err
  647. }
  648. // 更新服务主机信息
  649. host.Errors = 0
  650. host.Updated = time.Now()
  651. // 将连接加入现有连接中
  652. done = false
  653. h.connects.Range(func(key, _ any) bool {
  654. line := key.(*Line)
  655. if line.channel == channel {
  656. if line.state == Connected {
  657. if force {
  658. line.Close(true)
  659. } else {
  660. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  661. log.Println(err)
  662. return false
  663. }
  664. }
  665. line.Start(conn, host)
  666. done = true
  667. return false
  668. }
  669. return true
  670. })
  671. if err != nil {
  672. return err
  673. }
  674. // 新建一个连接
  675. if !done {
  676. line := NewConnect(h.cf, h, channel, conn, host)
  677. h.addLine(line)
  678. }
  679. return nil
  680. }
  681. // 重试方式连接服务
  682. // 将会一直阻塞直到连接成功
  683. func (h *Hub) ConnectToServerX(channel string, force bool) {
  684. host, _ := h.connectHostFunc(channel, Direct)
  685. for {
  686. err := h.ConnectToServer(channel, force, host)
  687. if err == nil {
  688. return
  689. }
  690. host = nil
  691. log.Println("[ConnectToServer ERROR, try it again]", err)
  692. // 产生一个随机数避免刹间重连过载
  693. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  694. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  695. }
  696. }
  697. // 检测处理代理连接
  698. func (h *Hub) checkProxyConnect() {
  699. if h.cf.ProxyTimeout <= 0 {
  700. return
  701. }
  702. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  703. for {
  704. <-proxyTicker.C
  705. now := time.Now().UnixMilli()
  706. h.connects.Range(func(key, _ any) bool {
  707. line := key.(*Line)
  708. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  709. host, err := h.connectHostFunc(line.channel, Direct)
  710. if err != nil {
  711. log.Println("[checkProxyConnect connectHostFunc ERROR]", err)
  712. return false
  713. }
  714. err = h.ConnectToServer(line.channel, true, host)
  715. if err != nil {
  716. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  717. }
  718. }
  719. return true
  720. })
  721. }
  722. }
  723. // 建立一个集线器
  724. // connectFunc 用于监听连接状态的函数,可以为nil
  725. func NewHub(
  726. cf *config.Config,
  727. channel string,
  728. // 客户端需要用的函数
  729. connectHostFunc ConnectHostFunc,
  730. authFunc AuthFunc,
  731. // 服务端需要用的函数
  732. checkAuthFunc CheckAuthFunc,
  733. // 连接状态变化时调用的函数
  734. connectStatusFunc ConnectStatusFunc,
  735. ) (h *Hub) {
  736. h = &Hub{
  737. cf: cf,
  738. channel: channel,
  739. middle: make([]MiddleFunc, 0),
  740. connectHostFunc: connectHostFunc,
  741. authFunc: authFunc,
  742. checkAuthFunc: checkAuthFunc,
  743. connectStatusFunc: connectStatusFunc,
  744. lastCleanDeadConnect: time.Now().UnixMilli(),
  745. }
  746. go h.checkProxyConnect()
  747. return h
  748. }