hub.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. cf *config.Config
  33. globalID uint16
  34. channel string // 本地频道信息
  35. middle []MiddleFunc // 中间件
  36. connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  37. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  38. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  39. // 客户端需要用的函数
  40. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  41. authFunc AuthFunc // 获取认证信息,用于发送给对方
  42. // 服务端需要用的函数
  43. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  44. // 连接状态变化时调用的函数
  45. connectStatusFunc ConnectStatusFunc
  46. // 上次清理异常连接时间戳
  47. lastCleanDeadConnect int64
  48. }
  49. // 清理异常连接
  50. func (h *Hub) cleanDeadConnect() {
  51. h.Lock()
  52. defer h.Unlock()
  53. now := time.Now().UnixMilli()
  54. if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) {
  55. h.lastCleanDeadConnect = now
  56. h.connects.Range(func(key, _ any) bool {
  57. line := key.(*Line)
  58. if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) {
  59. h.connects.Delete(key)
  60. }
  61. return true
  62. })
  63. }
  64. }
  65. // 获取通讯消息ID号
  66. func (h *Hub) GetID() uint16 {
  67. h.Lock()
  68. defer h.Unlock()
  69. h.globalID++
  70. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  71. h.globalID = 1
  72. }
  73. for {
  74. // 检查是否在请求队列中存在对应的id
  75. if _, ok := h.msgCache.Load(h.globalID); ok {
  76. h.globalID++
  77. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  78. h.globalID = 1
  79. }
  80. } else {
  81. break
  82. }
  83. }
  84. return h.globalID
  85. }
  86. // 添加中间件
  87. // 如果中间件函数返回为空,表示处理完成,通过
  88. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  89. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  90. h.middle = append(h.middle, middleFunc)
  91. }
  92. // 注册频道,其中频道为正则表达式字符串
  93. func (h *Hub) Subscribe(channel *regexp.Regexp, cmd string, backFunc SubscribeBack) (err error) {
  94. if channel == nil {
  95. return errors.New("channel can not be nil")
  96. }
  97. reg := &SubscribeData{
  98. Channel: channel,
  99. Cmd: cmd,
  100. BackFunc: backFunc,
  101. }
  102. sub, ok := h.subscribes.Load(cmd)
  103. if ok {
  104. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  105. return
  106. }
  107. regs := make([]*SubscribeData, 1)
  108. regs[0] = reg
  109. h.subscribes.Store(cmd, regs)
  110. return
  111. }
  112. // 遍历频道列表
  113. // 如果 fn 返回 false,则 range 停止迭代
  114. func (h *Hub) ConnectRange(fn func(line *Line) bool) {
  115. h.connects.Range(func(key, _ any) bool {
  116. line := key.(*Line)
  117. return fn(line)
  118. })
  119. }
  120. // 获取当前在线的数量
  121. func (h *Hub) ConnectNum() int {
  122. var count int
  123. h.connects.Range(func(key, _ any) bool {
  124. if key.(*Line).state == Connected {
  125. count++
  126. }
  127. return true
  128. })
  129. return count
  130. }
  131. // 获取所有的在线连接频道
  132. func (h *Hub) AllChannel() []string {
  133. cs := make([]string, 0)
  134. h.connects.Range(func(key, _ any) bool {
  135. line := key.(*Line)
  136. if line.state == Connected {
  137. cs = append(cs, line.channel)
  138. }
  139. return true
  140. })
  141. return cs
  142. }
  143. // 获取所有连接频道和连接时长
  144. // 为了避免定义数据结构麻烦,采用|隔开
  145. func (h *Hub) AllChannelTime() []string {
  146. cs := make([]string, 0)
  147. h.connects.Range(func(key, value any) bool {
  148. line := key.(*Line)
  149. if line.state == Connected {
  150. ti := time.Since(value.(time.Time)).Milliseconds()
  151. cs = append(cs, line.channel+"|"+strconv.FormatInt(ti, 10))
  152. }
  153. return true
  154. })
  155. return cs
  156. }
  157. // 获取频道并通过函数过滤,如果返回 false 将终止
  158. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  159. h.connects.Range(func(key, _ any) bool {
  160. line := key.(*Line)
  161. if line.state == Connected {
  162. return fn(line.channel)
  163. }
  164. return true
  165. })
  166. }
  167. // 从 channel 获取连接
  168. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  169. h.connects.Range(func(key, _ any) bool {
  170. l := key.(*Line)
  171. if l.channel == channel {
  172. line = l
  173. return false
  174. }
  175. return true
  176. })
  177. return
  178. }
  179. // 返回请求结果
  180. func (h *Hub) outResponse(response *ResponseData) {
  181. defer recover() //避免管道已经关闭而引起panic
  182. id := response.Id
  183. t, ok := h.msgCache.Load(id)
  184. if ok {
  185. // 删除数据缓存
  186. h.msgCache.Delete(id)
  187. gm := t.(*GetMsg)
  188. // 停止定时器
  189. if !gm.timer.Stop() {
  190. select {
  191. case <-gm.timer.C:
  192. default:
  193. }
  194. }
  195. // 回应数据到上层
  196. gm.out <- response
  197. }
  198. }
  199. // 发送数据到网络接口
  200. // 返回发送的数量
  201. func (h *Hub) sendRequest(gd *GetData) (count int) {
  202. h.connects.Range(func(key, _ any) bool {
  203. conn := key.(*Line)
  204. // 检查连接是否OK
  205. if conn.state != Connected {
  206. return true
  207. }
  208. if gd.Channel.MatchString(conn.channel) {
  209. var id uint16
  210. if gd.backchan != nil {
  211. id = h.GetID()
  212. timeout := gd.Timeout
  213. if timeout <= 0 {
  214. timeout = h.cf.WriteWait
  215. }
  216. fn := func(id uint16, conn *Line) func() {
  217. return func() {
  218. go h.outResponse(&ResponseData{
  219. Id: id,
  220. State: config.GET_TIMEOUT,
  221. Data: fmt.Appendf(nil, "[%s] %s %s", config.GET_TIMEOUT_MSG, gd.Channel.String(), gd.Cmd),
  222. conn: conn,
  223. })
  224. // 检查是否已经很久时间没有使用连接了
  225. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  226. // 超时关闭当前的连接
  227. log.Println("get message timeout", conn.channel)
  228. // 有可能连接出现问题,断开并重新连接
  229. conn.Close(false)
  230. return
  231. }
  232. }
  233. }(id, conn)
  234. // 将要发送的请求缓存
  235. gm := &GetMsg{
  236. out: gd.backchan,
  237. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  238. }
  239. h.msgCache.Store(id, gm)
  240. }
  241. // 组织数据并发送到Connect
  242. conn.sendRequest <- &RequestData{
  243. Id: id,
  244. Cmd: gd.Cmd,
  245. Data: gd.Data,
  246. timeout: gd.Timeout,
  247. backchan: gd.backchan,
  248. conn: conn,
  249. }
  250. if h.cf.PrintMsg {
  251. log.Println("[SEND]->", id, conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200))
  252. }
  253. count++
  254. if gd.Max > 0 && count >= gd.Max {
  255. return false
  256. }
  257. }
  258. return true
  259. })
  260. return
  261. }
  262. // 执行网络发送过来的命令
  263. func (h *Hub) requestFromNet(request *RequestData) {
  264. cmd := request.Cmd
  265. channel := request.conn.channel
  266. if h.cf.PrintMsg {
  267. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  268. }
  269. // 执行中间件
  270. for _, mdFunc := range h.middle {
  271. rsp := mdFunc(request)
  272. if rsp != nil {
  273. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  274. if rsp.State == config.NEXT_MIDDLE {
  275. continue
  276. }
  277. // 返回消息
  278. if request.Id != 0 {
  279. rsp.Id = request.Id
  280. request.conn.sendResponse <- rsp
  281. }
  282. return
  283. } else {
  284. break
  285. }
  286. }
  287. sub, ok := h.subscribes.Load(cmd)
  288. if ok {
  289. subs := sub.([]*SubscribeData)
  290. // 倒序查找是为了新增的频道响应函数优先执行
  291. for i := len(subs) - 1; i >= 0; i-- {
  292. rg := subs[i]
  293. if rg.Channel.MatchString(channel) {
  294. state, data := rg.BackFunc(request)
  295. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  296. if state == config.NEXT_SUBSCRIBE {
  297. continue
  298. }
  299. var byteData []byte
  300. switch data := data.(type) {
  301. case []byte:
  302. byteData = data
  303. case string:
  304. byteData = []byte(data)
  305. default:
  306. if data != nil {
  307. // 自动转换数据为json格式
  308. var err error
  309. byteData, err = json.Marshal(data)
  310. if err != nil {
  311. log.Println(err.Error())
  312. state = config.CONVERT_FAILED
  313. byteData = fmt.Appendf(nil, "[%s] %s %s", config.CONVERT_FAILED_MSG, request.conn.channel, request.Cmd)
  314. }
  315. }
  316. }
  317. // 如果id为0表示不需要回应
  318. if request.Id != 0 {
  319. request.conn.sendResponse <- &ResponseData{
  320. Id: request.Id,
  321. State: state,
  322. Data: byteData,
  323. }
  324. if h.cf.PrintMsg {
  325. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  326. }
  327. }
  328. return
  329. }
  330. }
  331. }
  332. log.Println("[not match command]", channel, cmd)
  333. // 返回没有匹配的消息
  334. request.conn.sendResponse <- &ResponseData{
  335. Id: request.Id,
  336. State: config.NO_MATCH,
  337. Data: fmt.Appendf(nil, "[%s] %s %s", config.NO_MATCH_MSG, channel, cmd),
  338. }
  339. }
  340. // 请求频道并获取数据,采用回调的方式返回结果
  341. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  342. // 如果 backFunc 返回为 false 则提前结束
  343. // 最大数量和超时时间如果为0的话表示使用默认值
  344. func (h *Hub) GetWithMaxAndTimeout(channel *regexp.Regexp, cmd string, data any, backFunc GetBack, max int, timeout int) (count int) {
  345. var reqData []byte
  346. switch data := data.(type) {
  347. case []byte:
  348. reqData = data
  349. case string:
  350. reqData = []byte(data)
  351. default:
  352. if data != nil {
  353. // 自动转换数据为json格式
  354. var err error
  355. reqData, err = json.Marshal(data)
  356. if err != nil {
  357. log.Println(err.Error())
  358. return 0
  359. }
  360. }
  361. }
  362. // 排除空频道
  363. if channel == nil {
  364. return 0
  365. }
  366. if timeout <= 0 {
  367. timeout = h.cf.ReadWait
  368. }
  369. gd := &GetData{
  370. Channel: channel,
  371. Cmd: cmd,
  372. Data: reqData,
  373. Max: max,
  374. Timeout: timeout,
  375. backchan: make(chan *ResponseData, 32),
  376. }
  377. sendMax := h.sendRequest(gd)
  378. if sendMax <= 0 {
  379. return 0
  380. }
  381. // 避免出现异常时线程无法退出
  382. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  383. defer func() {
  384. if !timer.Stop() {
  385. select {
  386. case <-timer.C:
  387. default:
  388. }
  389. }
  390. close(gd.backchan)
  391. }()
  392. for {
  393. select {
  394. case rp := <-gd.backchan:
  395. if rp == nil || rp.conn == nil {
  396. // 可能是已经退出了
  397. return
  398. }
  399. ch := rp.conn.channel
  400. if h.cf.PrintMsg {
  401. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  402. }
  403. count++
  404. // 如果这里返回为false这跳出循环
  405. if backFunc != nil && !backFunc(rp) {
  406. return
  407. }
  408. if count >= sendMax {
  409. return
  410. }
  411. case <-timer.C:
  412. return
  413. }
  414. }
  415. // return
  416. }
  417. // 请求频道并获取数据,采用回调的方式返回结果
  418. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  419. // 如果 backFunc 返回为 false 则提前结束
  420. func (h *Hub) Get(channel *regexp.Regexp, cmd string, data any, backFunc GetBack) (count int) {
  421. return h.GetWithMaxAndTimeout(channel, cmd, data, backFunc, 0, 0)
  422. }
  423. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  424. // 如果没有结果将返回 NO_MATCH
  425. func (h *Hub) GetOne(channel *regexp.Regexp, cmd string, data any) (response *ResponseData) {
  426. h.GetWithMaxAndTimeout(channel, cmd, data, func(rp *ResponseData) (ok bool) {
  427. response = rp
  428. return false
  429. }, 1, 0)
  430. if response == nil {
  431. response = &ResponseData{
  432. State: config.CONNECT_NO_MATCH,
  433. Data: fmt.Appendf(nil, "[%s] %s %s", config.CONNECT_NO_MATCH_MSG, channel.String(), cmd),
  434. }
  435. }
  436. return
  437. }
  438. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  439. // 如果没有结果将返回 NO_MATCH
  440. func (h *Hub) GetOneWithTimeout(channel *regexp.Regexp, cmd string, data any, timeout int) (response *ResponseData) {
  441. h.GetWithMaxAndTimeout(channel, cmd, data, func(rp *ResponseData) (ok bool) {
  442. response = rp
  443. return false
  444. }, 1, timeout)
  445. if response == nil {
  446. response = &ResponseData{
  447. State: config.CONNECT_NO_MATCH,
  448. Data: fmt.Appendf(nil, "[%s] %s %s", config.CONNECT_NO_MATCH_MSG, channel.String(), cmd),
  449. }
  450. }
  451. return
  452. }
  453. // 推送消息出去,不需要返回数据
  454. func (h *Hub) Push(channel *regexp.Regexp, cmd string, data any) {
  455. h.PushWithMax(channel, cmd, data, 0)
  456. }
  457. // 推送最大对应数量的消息出去,不需要返回数据
  458. func (h *Hub) PushWithMax(channel *regexp.Regexp, cmd string, data any, max int) {
  459. // 排除空频道
  460. if channel == nil {
  461. return
  462. }
  463. var reqData []byte
  464. switch data := data.(type) {
  465. case []byte:
  466. reqData = data
  467. case string:
  468. reqData = []byte(data)
  469. default:
  470. if data != nil {
  471. // 自动转换数据为json格式
  472. var err error
  473. reqData, err = json.Marshal(data)
  474. if err != nil {
  475. log.Println(err.Error())
  476. return
  477. }
  478. }
  479. }
  480. gd := &GetData{
  481. Channel: channel,
  482. Cmd: cmd,
  483. Data: reqData,
  484. Max: max,
  485. Timeout: h.cf.ReadWait,
  486. backchan: nil,
  487. }
  488. h.sendRequest(gd)
  489. }
  490. // 增加连接
  491. func (h *Hub) addLine(line *Line) {
  492. if _, ok := h.connects.Load(line); ok {
  493. log.Println("connect have exist")
  494. // 连接已经存在,直接返回
  495. return
  496. }
  497. // 检查是否有相同的channel,如果有的话将其关闭删除
  498. channel := line.channel
  499. h.connects.Range(func(key, _ any) bool {
  500. conn := key.(*Line)
  501. // 删除超时的连接
  502. if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5*int(time.Millisecond)) {
  503. h.connects.Delete(key)
  504. return true
  505. }
  506. if conn.channel == channel {
  507. conn.Close(true)
  508. h.connects.Delete(key)
  509. return false
  510. }
  511. return true
  512. })
  513. h.connects.Store(line, true)
  514. }
  515. // 删除连接
  516. func (h *Hub) removeLine(conn *Line) {
  517. conn.Close(true)
  518. h.connects.Delete(conn)
  519. }
  520. // 获取指定连接的连接持续时间
  521. func (h *Hub) ConnectDuration(conn *Line) time.Duration {
  522. t, ok := h.connects.Load(conn)
  523. if ok {
  524. return time.Since(t.(time.Time))
  525. }
  526. // 如果不存在直接返回0
  527. return time.Duration(0)
  528. }
  529. // 绑定端口,建立服务
  530. // 需要程序运行时调用
  531. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  532. doConnectFunc := func(conn conn.Connect) {
  533. proto, version, channel, auth, err := conn.ReadAuthInfo()
  534. if err != nil {
  535. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  536. conn.Close()
  537. return
  538. }
  539. if version != info.Version || proto != info.Proto {
  540. log.Println("wrong version or protocol: ", version, proto)
  541. conn.Close()
  542. return
  543. }
  544. // 检查验证是否合法
  545. if !h.checkAuthFunc(proto, version, channel, auth) {
  546. conn.Close()
  547. return
  548. }
  549. // 发送频道信息
  550. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  551. log.Println("[WriteAuthInfo ERROR]", err)
  552. conn.Close()
  553. return
  554. }
  555. // 将连接加入现有连接中
  556. done := false
  557. h.connects.Range(func(key, _ any) bool {
  558. line := key.(*Line)
  559. if line.state == Disconnected && line.channel == channel && line.host == nil {
  560. line.Start(conn, nil)
  561. done = true
  562. return false
  563. }
  564. return true
  565. })
  566. // 新建一个连接
  567. if !done {
  568. line := NewConnect(h.cf, h, channel, conn, nil)
  569. h.addLine(line)
  570. }
  571. }
  572. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  573. bind := ""
  574. if info.Bind != "" {
  575. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  576. }
  577. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  578. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  579. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  580. }
  581. return errors.New("not connect protocol and version found")
  582. }
  583. // 新建一个连接,不同的连接协议由底层自己选择
  584. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  585. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  586. // 检查当前channel是否已经存在
  587. if !force {
  588. line := h.ChannelToLine(channel)
  589. if line != nil && line.state == Connected {
  590. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  591. return
  592. }
  593. }
  594. if host == nil {
  595. // 获取服务地址等信息
  596. host, err = h.connectHostFunc(channel, Both)
  597. if err != nil {
  598. return err
  599. }
  600. }
  601. var conn conn.Connect
  602. var runProto string
  603. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  604. // 添加定时器
  605. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  606. defer cancel()
  607. taskCh := make(chan bool)
  608. done := false
  609. go func() {
  610. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  611. runProto = ws2.PROTO
  612. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  613. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  614. runProto = tcp2.PROTO
  615. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  616. } else {
  617. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  618. }
  619. if done {
  620. if err != nil {
  621. log.Println("[Dial ERROR]", err)
  622. }
  623. if conn != nil {
  624. conn.Close()
  625. }
  626. } else {
  627. taskCh <- err == nil
  628. }
  629. }()
  630. select {
  631. case ok := <-taskCh:
  632. cancel()
  633. if !ok || err != nil || conn == nil {
  634. log.Println("[Client ERROR]", host.Proto, err)
  635. host.Errors++
  636. host.Updated = time.Now()
  637. if err == nil {
  638. err = errors.New("unknown error")
  639. }
  640. return err
  641. }
  642. case <-ctx.Done():
  643. done = true
  644. return errors.New("timeout")
  645. }
  646. // 发送验证信息
  647. if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
  648. log.Println("[WriteAuthInfo ERROR]", err)
  649. conn.Close()
  650. host.Errors++
  651. host.Updated = time.Now()
  652. return err
  653. }
  654. // 接收频道信息
  655. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  656. if err != nil {
  657. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  658. conn.Close()
  659. host.Errors++
  660. host.Updated = time.Now()
  661. return err
  662. }
  663. // 检查版本和协议是否一致
  664. if version != host.Version || proto != runProto {
  665. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  666. log.Println(err)
  667. conn.Close()
  668. host.Errors++
  669. host.Updated = time.Now()
  670. return err
  671. }
  672. // 检查频道名称是否匹配
  673. if !strings.Contains(channel2, channel) {
  674. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  675. log.Println(err)
  676. conn.Close()
  677. host.Errors++
  678. host.Updated = time.Now()
  679. return err
  680. }
  681. // 检查验证是否合法
  682. if !h.checkAuthFunc(proto, version, channel, auth) {
  683. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  684. log.Println(err)
  685. conn.Close()
  686. host.Errors++
  687. host.Updated = time.Now()
  688. return err
  689. }
  690. // 更新服务主机信息
  691. host.Errors = 0
  692. host.Updated = time.Now()
  693. // 将连接加入现有连接中
  694. done = false
  695. h.connects.Range(func(key, _ any) bool {
  696. line := key.(*Line)
  697. if line.channel == channel {
  698. if line.state == Connected {
  699. if force {
  700. line.Close(true)
  701. } else {
  702. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  703. log.Println(err)
  704. return false
  705. }
  706. }
  707. line.Start(conn, host)
  708. done = true
  709. return false
  710. }
  711. return true
  712. })
  713. if err != nil {
  714. return err
  715. }
  716. // 新建一个连接
  717. if !done {
  718. line := NewConnect(h.cf, h, channel, conn, host)
  719. h.addLine(line)
  720. }
  721. return nil
  722. }
  723. // 重试方式连接服务
  724. // 将会一直阻塞直到连接成功
  725. func (h *Hub) ConnectToServerX(channel string, force bool) {
  726. host, _ := h.connectHostFunc(channel, Direct)
  727. for {
  728. err := h.ConnectToServer(channel, force, host)
  729. if err == nil {
  730. return
  731. }
  732. log.Println("[ConnectToServer ERROR, try it again]", channel, host, err)
  733. host = nil
  734. // 产生一个随机数避免刹间重连过载
  735. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  736. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  737. }
  738. }
  739. // 检测处理代理连接
  740. func (h *Hub) checkProxyConnect() {
  741. if h.cf.ProxyTimeout <= 0 {
  742. return
  743. }
  744. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  745. for {
  746. <-proxyTicker.C
  747. now := time.Now().UnixMilli()
  748. h.connects.Range(func(key, _ any) bool {
  749. line := key.(*Line)
  750. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  751. host, err := h.connectHostFunc(line.channel, Direct)
  752. if err != nil {
  753. log.Println("[checkProxyConnect connectHostFunc ERROR]", err)
  754. return false
  755. }
  756. err = h.ConnectToServer(line.channel, true, host)
  757. if err != nil {
  758. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  759. }
  760. }
  761. return true
  762. })
  763. }
  764. }
  765. // 建立一个集线器
  766. // connectFunc 用于监听连接状态的函数,可以为nil
  767. func NewHub(
  768. cf *config.Config,
  769. channel string,
  770. // 客户端需要用的函数
  771. connectHostFunc ConnectHostFunc,
  772. authFunc AuthFunc,
  773. // 服务端需要用的函数
  774. checkAuthFunc CheckAuthFunc,
  775. // 连接状态变化时调用的函数
  776. connectStatusFunc ConnectStatusFunc,
  777. ) (h *Hub) {
  778. h = &Hub{
  779. cf: cf,
  780. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  781. channel: channel,
  782. middle: make([]MiddleFunc, 0),
  783. connectHostFunc: connectHostFunc,
  784. authFunc: authFunc,
  785. checkAuthFunc: checkAuthFunc,
  786. connectStatusFunc: connectStatusFunc,
  787. lastCleanDeadConnect: time.Now().UnixMilli(),
  788. }
  789. go h.checkProxyConnect()
  790. return h
  791. }