hub.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. package tinymq
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "math/rand"
  7. "net"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "git.me9.top/git/tinymq/config"
  13. "git.me9.top/git/tinymq/conn"
  14. "git.me9.top/git/tinymq/conn/tpv2"
  15. "git.me9.top/git/tinymq/conn/wsv2"
  16. )
  17. // 类似一个插座的功能,管理多个连接
  18. // 一个hub即可以是客户端,同时也可以是服务端
  19. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  20. // 截取部分字符串
  21. func subStr(str string, length int) string {
  22. if len(str) <= length {
  23. return str
  24. }
  25. return str[0:length] + "..."
  26. }
  27. type Hub struct {
  28. sync.Mutex
  29. cf *config.Config
  30. globalID uint16
  31. channel string // 本地频道信息
  32. connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  33. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  34. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  35. // 客户端需要用的函数
  36. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  37. authFunc AuthFunc // 获取认证信息,用于发送给对方
  38. // 服务端需要用的函数
  39. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  40. // 连接状态变化时调用的函数
  41. connectStatusFunc ConnectStatusFunc
  42. // 上次清理异常连接时间戳
  43. lastCleanDeadConnect int64
  44. }
  45. // 清理异常连接
  46. func (h *Hub) cleanDeadConnect() {
  47. h.Lock()
  48. defer h.Unlock()
  49. now := time.Now().UnixMilli()
  50. if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) {
  51. h.lastCleanDeadConnect = now
  52. h.connects.Range(func(key, _ any) bool {
  53. line := key.(*Line)
  54. if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) {
  55. h.connects.Delete(key)
  56. }
  57. return true
  58. })
  59. }
  60. }
  61. // 获取通讯消息ID号
  62. func (h *Hub) GetID() uint16 {
  63. h.Lock()
  64. defer h.Unlock()
  65. h.globalID++
  66. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  67. h.globalID = 1
  68. }
  69. for {
  70. // 检查是否在请求队列中存在对应的id
  71. if _, ok := h.msgCache.Load(h.globalID); ok {
  72. h.globalID++
  73. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  74. h.globalID = 1
  75. }
  76. } else {
  77. break
  78. }
  79. }
  80. return h.globalID
  81. }
  82. // 注册频道,其中频道为正则表达式字符串
  83. func (h *Hub) Subscribe(reg *SubscribeData) (err error) {
  84. if reg.Channel == nil {
  85. return errors.New("channel can not be nil")
  86. }
  87. cmd := reg.Cmd
  88. sub, ok := h.subscribes.Load(cmd)
  89. if ok {
  90. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  91. return
  92. }
  93. regs := make([]*SubscribeData, 1)
  94. regs[0] = reg
  95. h.subscribes.Store(cmd, regs)
  96. return
  97. }
  98. // 获取当前在线的数量
  99. func (h *Hub) ConnectNum() int {
  100. var count int
  101. h.connects.Range(func(key, _ any) bool {
  102. if key.(*Line).state == Connected {
  103. count++
  104. }
  105. return true
  106. })
  107. return count
  108. }
  109. // 获取所有的在线连接频道
  110. func (h *Hub) AllChannel() []string {
  111. cs := make([]string, 0)
  112. h.connects.Range(func(key, _ any) bool {
  113. line := key.(*Line)
  114. if line.state == Connected {
  115. cs = append(cs, line.channel)
  116. }
  117. return true
  118. })
  119. return cs
  120. }
  121. // 获取所有连接频道和连接时长
  122. // 为了避免定义数据结构麻烦,采用|隔开
  123. func (h *Hub) AllChannelTime() []string {
  124. cs := make([]string, 0)
  125. h.connects.Range(func(key, value any) bool {
  126. line := key.(*Line)
  127. if line.state == Connected {
  128. ti := time.Since(value.(time.Time)).Milliseconds()
  129. cs = append(cs, line.channel+"|"+strconv.FormatInt(ti, 10))
  130. }
  131. return true
  132. })
  133. return cs
  134. }
  135. // 获取频道并通过函数过滤,如果返回 false 将终止
  136. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  137. h.connects.Range(func(key, _ any) bool {
  138. line := key.(*Line)
  139. if line.state == Connected {
  140. return fn(line.channel)
  141. }
  142. return true
  143. })
  144. }
  145. // 从 channel 获取连接
  146. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  147. h.connects.Range(func(key, _ any) bool {
  148. l := key.(*Line)
  149. if l.channel == channel {
  150. line = l
  151. return false
  152. }
  153. return true
  154. })
  155. return
  156. }
  157. // 返回请求结果
  158. func (h *Hub) outResponse(response *ResponseData) {
  159. defer recover() //避免管道已经关闭而引起panic
  160. id := response.Id
  161. t, ok := h.msgCache.Load(id)
  162. if ok {
  163. // 删除数据缓存
  164. h.msgCache.Delete(id)
  165. gm := t.(*GetMsg)
  166. // 停止定时器
  167. if !gm.timer.Stop() {
  168. select {
  169. case <-gm.timer.C:
  170. default:
  171. }
  172. }
  173. // 回应数据到上层
  174. gm.out <- response
  175. }
  176. }
  177. // 发送数据到网络接口
  178. // 返回发送的数量
  179. func (h *Hub) sendRequest(gd *GetData) (count int) {
  180. h.connects.Range(func(key, _ any) bool {
  181. conn := key.(*Line)
  182. // 检查连接是否OK
  183. if conn.state != Connected {
  184. return true
  185. }
  186. if gd.Channel.MatchString(conn.channel) {
  187. var id uint16
  188. if gd.backchan != nil {
  189. id = h.GetID()
  190. timeout := gd.Timeout
  191. if timeout <= 0 {
  192. timeout = h.cf.WriteWait
  193. }
  194. fn := func(id uint16, conn *Line) func() {
  195. return func() {
  196. go h.outResponse(&ResponseData{
  197. Id: id,
  198. State: config.GET_TIMEOUT,
  199. Data: []byte(config.GET_TIMEOUT_MSG),
  200. conn: conn,
  201. })
  202. // 检查是否已经很久时间没有使用连接了
  203. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3)*time.Millisecond {
  204. // 超时关闭当前的连接
  205. log.Println("get message timeout", conn.channel)
  206. // 有可能连接出现问题,断开并重新连接
  207. conn.Close(false)
  208. return
  209. }
  210. }
  211. }(id, conn)
  212. // 将要发送的请求缓存
  213. gm := &GetMsg{
  214. out: gd.backchan,
  215. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  216. }
  217. h.msgCache.Store(id, gm)
  218. }
  219. // 组织数据并发送到Connect
  220. conn.sendRequest <- &RequestData{
  221. Id: id,
  222. Cmd: gd.Cmd,
  223. Data: gd.Data,
  224. timeout: gd.Timeout,
  225. backchan: gd.backchan,
  226. conn: conn,
  227. }
  228. log.Println("[SEND]->", conn.channel, "["+gd.Cmd+"]", subStr(string(gd.Data), 200))
  229. count++
  230. if gd.Max > 0 && count >= gd.Max {
  231. return false
  232. }
  233. }
  234. return true
  235. })
  236. return
  237. }
  238. // 执行网络发送过来的命令
  239. func (h *Hub) requestFromNet(request *RequestData) {
  240. cmd := request.Cmd
  241. channel := request.conn.channel
  242. log.Println("[REQU]<-", channel, "["+cmd+"]", subStr(string(request.Data), 200))
  243. sub, ok := h.subscribes.Load(cmd)
  244. if ok {
  245. subs := sub.([]*SubscribeData)
  246. // 倒序查找是为了新增的频道响应函数优先执行
  247. for i := len(subs) - 1; i >= 0; i-- {
  248. rg := subs[i]
  249. if rg.Channel.MatchString(channel) {
  250. state, data := rg.BackFunc(request)
  251. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  252. if state == config.NEXT_SUBSCRIBE {
  253. continue
  254. }
  255. // 如果id为0表示不需要回应
  256. if request.Id != 0 {
  257. request.conn.sendResponse <- &ResponseData{
  258. Id: request.Id,
  259. State: state,
  260. Data: data,
  261. }
  262. log.Println("[RESP]->", channel, "["+cmd+"]", state, subStr(string(data), 200))
  263. }
  264. return
  265. }
  266. }
  267. }
  268. log.Println("[not match command]", channel, cmd)
  269. // 返回没有匹配的消息
  270. request.conn.sendResponse <- &ResponseData{
  271. Id: request.Id,
  272. State: config.NO_MATCH,
  273. Data: []byte(config.NO_MATCH_MSG),
  274. }
  275. }
  276. // 请求频道并获取数据,采用回调的方式返回结果
  277. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  278. // 如果 backFunc 返回为 false 则提前结束
  279. func (h *Hub) Get(gd *GetData, backFunc GetBack) (count int) {
  280. // 排除空频道
  281. if gd.Channel == nil {
  282. return 0
  283. }
  284. if gd.Timeout <= 0 {
  285. gd.Timeout = h.cf.ReadWait
  286. }
  287. if gd.backchan == nil {
  288. gd.backchan = make(chan *ResponseData, 32)
  289. }
  290. max := h.sendRequest(gd)
  291. if max <= 0 {
  292. return 0
  293. }
  294. // 避免出现异常时线程无法退出
  295. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  296. defer func() {
  297. if !timer.Stop() {
  298. select {
  299. case <-timer.C:
  300. default:
  301. }
  302. }
  303. close(gd.backchan)
  304. }()
  305. for {
  306. select {
  307. case rp := <-gd.backchan:
  308. if rp == nil || rp.conn == nil {
  309. // 可能是已经退出了
  310. return
  311. }
  312. ch := rp.conn.channel
  313. log.Println("[RECV]<-", ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  314. count++
  315. // 如果这里返回为false这跳出循环
  316. if backFunc != nil && !backFunc(rp) {
  317. return
  318. }
  319. if count >= max {
  320. return
  321. }
  322. case <-timer.C:
  323. return
  324. }
  325. }
  326. // return
  327. }
  328. // 只获取一个频道的数据,阻塞等待
  329. // 如果没有结果将返回 NO_MATCH
  330. func (h *Hub) GetOne(cmd *GetData) (response *ResponseData) {
  331. cmd.Max = 1
  332. h.Get(cmd, func(rp *ResponseData) (ok bool) {
  333. response = rp
  334. return false
  335. })
  336. if response == nil {
  337. response = &ResponseData{
  338. State: config.NO_MATCH,
  339. Data: []byte(config.NO_MATCH_MSG),
  340. }
  341. }
  342. return
  343. }
  344. // 推送消息出去,不需要返回数据
  345. func (h *Hub) Push(cmd *GetData) {
  346. cmd.backchan = nil
  347. h.sendRequest(cmd)
  348. }
  349. // 增加连接
  350. func (h *Hub) addLine(line *Line) {
  351. if _, ok := h.connects.Load(line); ok {
  352. log.Println("connect have exist")
  353. // 连接已经存在,直接返回
  354. return
  355. }
  356. // 检查是否有相同的channel,如果有的话将其关闭删除
  357. channel := line.channel
  358. h.connects.Range(func(key, _ any) bool {
  359. conn := key.(*Line)
  360. // 删除超时的连接
  361. if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5)*time.Millisecond {
  362. h.connects.Delete(key)
  363. return true
  364. }
  365. if conn.channel == channel {
  366. conn.Close(true)
  367. h.connects.Delete(key)
  368. return false
  369. }
  370. return true
  371. })
  372. h.connects.Store(line, true)
  373. }
  374. // 删除连接
  375. func (h *Hub) removeLine(conn *Line) {
  376. conn.Close(true)
  377. h.connects.Delete(conn)
  378. }
  379. // 获取指定连接的连接持续时间
  380. func (h *Hub) ConnectDuration(conn *Line) time.Duration {
  381. t, ok := h.connects.Load(conn)
  382. if ok {
  383. return time.Since(t.(time.Time))
  384. }
  385. // 如果不存在直接返回0
  386. return time.Duration(0)
  387. }
  388. // 绑定端口,建立服务
  389. // 需要程序运行时调用
  390. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  391. doConnectFunc := func(conn conn.Connect) {
  392. proto, version, channel, auth, err := conn.ReadAuthInfo()
  393. if err != nil {
  394. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  395. conn.Close()
  396. return
  397. }
  398. if version != info.Version || proto != info.Proto {
  399. log.Println("wrong version or protocol: ", version, proto)
  400. conn.Close()
  401. return
  402. }
  403. // 检查验证是否合法
  404. if !h.checkAuthFunc(proto, version, channel, auth) {
  405. conn.Close()
  406. return
  407. }
  408. // 发送频道信息
  409. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  410. log.Println("[WriteAuthInfo ERROR]", err)
  411. conn.Close()
  412. return
  413. }
  414. // 将连接加入现有连接中
  415. done := false
  416. h.connects.Range(func(key, _ any) bool {
  417. line := key.(*Line)
  418. if line.state == Disconnected && line.channel == channel && line.host == nil {
  419. line.Start(conn, nil)
  420. done = true
  421. return false
  422. }
  423. return true
  424. })
  425. // 新建一个连接
  426. if !done {
  427. line := NewConnect(h.cf, h, channel, conn, nil)
  428. h.addLine(line)
  429. }
  430. }
  431. if info.Version == wsv2.VERSION && info.Proto == wsv2.PROTO {
  432. bind := ""
  433. if info.Bind != "" {
  434. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  435. }
  436. return wsv2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  437. } else if info.Version == tpv2.VERSION && info.Proto == tpv2.PROTO {
  438. return tpv2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  439. }
  440. return errors.New("not connect protocol and version found")
  441. }
  442. // 新建一个连接,不同的连接协议由底层自己选择
  443. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  444. func (h *Hub) ConnectToServer(channel string, force bool) (err error) {
  445. // 检查当前channel是否已经存在
  446. if !force {
  447. line := h.ChannelToLine(channel)
  448. if line != nil && line.state == Connected {
  449. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  450. return
  451. }
  452. }
  453. // 获取服务地址等信息
  454. host, err := h.connectHostFunc(channel)
  455. if err != nil {
  456. return err
  457. }
  458. var conn conn.Connect
  459. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  460. if host.Version == wsv2.VERSION && host.Proto == wsv2.PROTO {
  461. conn, err = wsv2.Client(h.cf, addr, host.Path, host.Hash)
  462. } else if host.Version == tpv2.VERSION && host.Proto == tpv2.PROTO {
  463. conn, err = tpv2.Client(h.cf, addr, host.Hash)
  464. } else {
  465. return fmt.Errorf("not correct protocol and version found in: %+v", host)
  466. }
  467. if err != nil {
  468. log.Println("[Client ERROR]", host.Proto, err)
  469. host.Errors++
  470. host.Updated = time.Now()
  471. return err
  472. }
  473. // 发送验证信息
  474. if err := conn.WriteAuthInfo(h.channel, h.authFunc(host.Proto, host.Version, channel, nil)); err != nil {
  475. log.Println("[WriteAuthInfo ERROR]", err)
  476. conn.Close()
  477. host.Errors++
  478. host.Updated = time.Now()
  479. return err
  480. }
  481. // 接收频道信息
  482. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  483. if err != nil {
  484. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  485. conn.Close()
  486. host.Errors++
  487. host.Updated = time.Now()
  488. return err
  489. }
  490. // 检查版本和协议是否一致
  491. if version != host.Version || proto != host.Proto {
  492. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  493. log.Println(err)
  494. conn.Close()
  495. host.Errors++
  496. host.Updated = time.Now()
  497. return err
  498. }
  499. // 检查频道名称是否匹配
  500. if !strings.Contains(channel2, channel) {
  501. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  502. log.Println(err)
  503. conn.Close()
  504. host.Errors++
  505. host.Updated = time.Now()
  506. return err
  507. }
  508. // 检查验证是否合法
  509. if !h.checkAuthFunc(proto, version, channel, auth) {
  510. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  511. log.Println(err)
  512. conn.Close()
  513. host.Errors++
  514. host.Updated = time.Now()
  515. return err
  516. }
  517. // 更新服务主机信息
  518. host.Errors = 0
  519. host.Updated = time.Now()
  520. // 将连接加入现有连接中
  521. done := false
  522. h.connects.Range(func(key, _ any) bool {
  523. line := key.(*Line)
  524. if line.channel == channel {
  525. if line.state == Connected {
  526. if !force {
  527. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  528. log.Println(err)
  529. return false
  530. }
  531. return true
  532. }
  533. line.Start(conn, host)
  534. done = true
  535. return false
  536. }
  537. return true
  538. })
  539. if err != nil {
  540. return err
  541. }
  542. // 新建一个连接
  543. if !done {
  544. line := NewConnect(h.cf, h, channel, conn, host)
  545. h.addLine(line)
  546. }
  547. return nil
  548. }
  549. // 重试方式连接服务
  550. // 将会一直阻塞直到连接成功
  551. func (h *Hub) ConnectToServerX(channel string, force bool) {
  552. for {
  553. err := h.ConnectToServer(channel, force)
  554. if err == nil {
  555. return
  556. }
  557. log.Println("[ConnectToServer ERROR, try it again]", err)
  558. // 产生一个随机数避免刹间重连过载
  559. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  560. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  561. }
  562. }
  563. // 建立一个集线器
  564. // connectFunc 用于监听连接状态的函数,可以为nil
  565. func NewHub(
  566. cf *config.Config,
  567. channel string,
  568. // 客户端需要用的函数
  569. connectHostFunc ConnectHostFunc,
  570. authFunc AuthFunc,
  571. // 服务端需要用的函数
  572. checkAuthFunc CheckAuthFunc,
  573. // 连接状态变化时调用的函数
  574. connectStatusFunc ConnectStatusFunc,
  575. ) (h *Hub) {
  576. h = &Hub{
  577. cf: cf,
  578. channel: channel,
  579. connectHostFunc: connectHostFunc,
  580. authFunc: authFunc,
  581. checkAuthFunc: checkAuthFunc,
  582. connectStatusFunc: connectStatusFunc,
  583. lastCleanDeadConnect: time.Now().UnixMilli(),
  584. }
  585. return h
  586. }