hub.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. // "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. cf *config.Config
  33. globalID uint16
  34. channel string // 本地频道信息
  35. middle []MiddleFunc // 中间件
  36. connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  37. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  38. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  39. // 客户端需要用的函数
  40. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  41. authFunc AuthFunc // 获取认证信息,用于发送给对方
  42. // 服务端需要用的函数
  43. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  44. // 连接状态变化时调用的函数
  45. connectStatusFunc ConnectStatusFunc
  46. // 验证发送的数据条件是否满足 (可为空)
  47. checkConnectOkFunc CheckConnectOkFunc
  48. // 上次清理异常连接时间戳
  49. lastCleanDeadConnect int64
  50. }
  51. // 转换数据
  52. func (h *Hub) convertData(data any) (reqData []byte, err error) {
  53. switch data := data.(type) {
  54. case []byte:
  55. reqData = data
  56. case string:
  57. reqData = []byte(data)
  58. case func() ([]byte, error):
  59. reqData, err = data()
  60. if err != nil {
  61. log.Println(err.Error())
  62. return nil, err
  63. }
  64. default:
  65. if data != nil {
  66. // if s, ok := data.(func() ([]uint8, error)); ok {
  67. // return s()
  68. // }
  69. // 自动转换数据为json格式
  70. reqData, err = json.Marshal(data)
  71. if err != nil {
  72. log.Println(err.Error())
  73. return nil, err
  74. }
  75. }
  76. }
  77. return
  78. }
  79. // 清理异常连接
  80. func (h *Hub) cleanDeadConnect() {
  81. h.Lock()
  82. defer h.Unlock()
  83. now := time.Now().UnixMilli()
  84. if now-h.lastCleanDeadConnect > int64(h.cf.CleanDeadConnectWait) {
  85. h.lastCleanDeadConnect = now
  86. h.connects.Range(func(key, _ any) bool {
  87. line := key.(*Line)
  88. if line.state != Connected && now-line.updated.UnixMilli() > int64(h.cf.CleanDeadConnectWait) {
  89. h.connects.Delete(key)
  90. }
  91. return true
  92. })
  93. }
  94. }
  95. // 获取通讯消息ID号
  96. func (h *Hub) GetID() uint16 {
  97. h.Lock()
  98. defer h.Unlock()
  99. h.globalID++
  100. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  101. h.globalID = 1
  102. }
  103. for {
  104. // 检查是否在请求队列中存在对应的id
  105. if _, ok := h.msgCache.Load(h.globalID); ok {
  106. h.globalID++
  107. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  108. h.globalID = 1
  109. }
  110. } else {
  111. break
  112. }
  113. }
  114. return h.globalID
  115. }
  116. // 添加中间件
  117. // 如果中间件函数返回为空,表示处理完成,通过
  118. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  119. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  120. h.middle = append(h.middle, middleFunc)
  121. }
  122. // 注册频道,其中频道为正则表达式字符串
  123. func (h *Hub) Subscribe(filter FilterFunc, cmd string, backFunc SubscribeBackFunc) (err error) {
  124. if filter == nil {
  125. return errors.New("filter function can not be nil")
  126. }
  127. reg := &SubscribeData{
  128. Filter: filter,
  129. Cmd: cmd,
  130. BackFunc: backFunc,
  131. }
  132. sub, ok := h.subscribes.Load(cmd)
  133. if ok {
  134. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  135. return
  136. }
  137. regs := make([]*SubscribeData, 1)
  138. regs[0] = reg
  139. h.subscribes.Store(cmd, regs)
  140. return
  141. }
  142. // 遍历频道列表
  143. // 如果 fn 返回 false,则 range 停止迭代
  144. func (h *Hub) ConnectRange(fn func(line *Line) bool) {
  145. h.connects.Range(func(key, _ any) bool {
  146. line := key.(*Line)
  147. return fn(line)
  148. })
  149. }
  150. // 获取当前在线的数量
  151. func (h *Hub) ConnectNum() int {
  152. var count int
  153. h.connects.Range(func(key, _ any) bool {
  154. if key.(*Line).state == Connected {
  155. count++
  156. }
  157. return true
  158. })
  159. return count
  160. }
  161. // 获取所有的在线连接频道
  162. func (h *Hub) AllChannel() []string {
  163. cs := make([]string, 0)
  164. h.connects.Range(func(key, _ any) bool {
  165. line := key.(*Line)
  166. if line.state == Connected {
  167. cs = append(cs, line.channel)
  168. }
  169. return true
  170. })
  171. return cs
  172. }
  173. // 获取所有连接频道和连接时长
  174. // 为了避免定义数据结构麻烦,采用|隔开, 频道名|连接开始时间
  175. func (h *Hub) AllChannelWithStarted() []string {
  176. cs := make([]string, 0)
  177. h.connects.Range(func(key, _ any) bool {
  178. line := key.(*Line)
  179. if line.state == Connected {
  180. // ti := time.Since(line.started).Milliseconds()
  181. cs = append(cs, line.channel+"|"+strconv.FormatInt(line.started.UnixMilli(), 10))
  182. }
  183. return true
  184. })
  185. return cs
  186. }
  187. // 获取频道并通过函数过滤,如果返回 false 将终止
  188. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  189. h.connects.Range(func(key, _ any) bool {
  190. line := key.(*Line)
  191. if line.state == Connected {
  192. return fn(line.channel)
  193. }
  194. return true
  195. })
  196. }
  197. // 从 channel 获取连接
  198. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  199. h.connects.Range(func(key, _ any) bool {
  200. l := key.(*Line)
  201. if l.channel == channel {
  202. line = l
  203. return false
  204. }
  205. return true
  206. })
  207. return
  208. }
  209. // 返回请求结果
  210. func (h *Hub) outResponse(response *ResponseData) {
  211. defer recover() //避免管道已经关闭而引起panic
  212. id := response.Id
  213. t, ok := h.msgCache.Load(id)
  214. if ok {
  215. // 删除数据缓存
  216. h.msgCache.Delete(id)
  217. gm := t.(*GetMsg)
  218. // 停止定时器
  219. if !gm.timer.Stop() {
  220. select {
  221. case <-gm.timer.C:
  222. default:
  223. }
  224. }
  225. // 回应数据到上层
  226. gm.out <- response
  227. }
  228. }
  229. // 发送数据到网络接口
  230. // 返回发送的数量
  231. func (h *Hub) sendRequest(gd *GetData) (count int) {
  232. outData, err := h.convertData(gd.Data)
  233. if err != nil {
  234. log.Println(err)
  235. return 0
  236. }
  237. // 增加如果没有发送到消息,延时100ms重连直到超时
  238. for i := 0; i < gd.Timeout; i += 100 {
  239. h.connects.Range(func(key, _ any) bool {
  240. conn := key.(*Line)
  241. // 检查连接是否OK
  242. if conn.state != Connected {
  243. return true
  244. }
  245. // 验证连接是否达到发送数据的要求
  246. if h.checkConnectOkFunc != nil && !h.checkConnectOkFunc(conn, gd) {
  247. return true
  248. }
  249. if gd.Filter(conn) {
  250. var id uint16
  251. if gd.backchan != nil {
  252. id = h.GetID()
  253. timeout := gd.Timeout
  254. if timeout <= 0 {
  255. timeout = h.cf.WriteWait
  256. }
  257. fn := func(id uint16, conn *Line) func() {
  258. return func() {
  259. go h.outResponse(&ResponseData{
  260. Id: id,
  261. State: config.GET_TIMEOUT,
  262. Data: fmt.Appendf(nil, "[%s] %s %s", config.GET_TIMEOUT_MSG, conn.channel, gd.Cmd),
  263. conn: conn,
  264. })
  265. // 检查是否已经很久时间没有使用连接了
  266. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  267. // 超时关闭当前的连接
  268. log.Println("get message timeout", conn.channel)
  269. // 有可能连接出现问题,断开并重新连接
  270. conn.Close(false)
  271. return
  272. }
  273. }
  274. }(id, conn)
  275. // 将要发送的请求缓存
  276. gm := &GetMsg{
  277. out: gd.backchan,
  278. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  279. }
  280. h.msgCache.Store(id, gm)
  281. }
  282. // 组织数据并发送到Connect
  283. conn.sendRequest <- &RequestData{
  284. Id: id,
  285. Cmd: gd.Cmd,
  286. Data: outData,
  287. timeout: gd.Timeout,
  288. backchan: gd.backchan,
  289. conn: conn,
  290. }
  291. if h.cf.PrintMsg {
  292. log.Println("[SEND]->", id, conn.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
  293. }
  294. count++
  295. if gd.Max > 0 && count >= gd.Max {
  296. return false
  297. }
  298. }
  299. return true
  300. })
  301. if count > 0 {
  302. break
  303. }
  304. time.Sleep(time.Millisecond * 100)
  305. }
  306. return
  307. }
  308. // 执行网络发送过来的命令
  309. func (h *Hub) requestFromNet(request *RequestData) {
  310. cmd := request.Cmd
  311. channel := request.conn.channel
  312. if h.cf.PrintMsg {
  313. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  314. }
  315. // 执行中间件
  316. for _, mdFunc := range h.middle {
  317. rsp := mdFunc(request)
  318. if rsp != nil {
  319. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  320. if rsp.State == config.NEXT_MIDDLE {
  321. continue
  322. }
  323. // 返回消息
  324. if request.Id != 0 {
  325. rsp.Id = request.Id
  326. request.conn.sendResponse <- rsp
  327. }
  328. return
  329. } else {
  330. break
  331. }
  332. }
  333. sub, ok := h.subscribes.Load(cmd)
  334. if ok {
  335. subs := sub.([]*SubscribeData)
  336. // 倒序查找是为了新增的频道响应函数优先执行
  337. for i := len(subs) - 1; i >= 0; i-- {
  338. rg := subs[i]
  339. // if rg.Channel.MatchString(channel) {
  340. if rg.Filter(request.conn) {
  341. state, data := rg.BackFunc(request)
  342. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  343. if state == config.NEXT_SUBSCRIBE {
  344. continue
  345. }
  346. var byteData []byte
  347. switch data := data.(type) {
  348. case []byte:
  349. byteData = data
  350. case string:
  351. byteData = []byte(data)
  352. default:
  353. if data != nil {
  354. // 自动转换数据为json格式
  355. var err error
  356. byteData, err = json.Marshal(data)
  357. if err != nil {
  358. log.Println(err.Error())
  359. state = config.CONVERT_FAILED
  360. byteData = fmt.Appendf(nil, "[%s] %s %s", config.CONVERT_FAILED_MSG, request.conn.channel, request.Cmd)
  361. }
  362. }
  363. }
  364. // 如果id为0表示不需要回应
  365. if request.Id != 0 {
  366. request.conn.sendResponse <- &ResponseData{
  367. Id: request.Id,
  368. State: state,
  369. Data: byteData,
  370. }
  371. if h.cf.PrintMsg {
  372. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  373. }
  374. }
  375. return
  376. }
  377. }
  378. }
  379. log.Println("[not match command]", channel, cmd)
  380. // 返回没有匹配的消息
  381. request.conn.sendResponse <- &ResponseData{
  382. Id: request.Id,
  383. State: config.NO_MATCH,
  384. Data: fmt.Appendf(nil, "[%s] %s %s", config.NO_MATCH_MSG, channel, cmd),
  385. }
  386. }
  387. // 请求频道并获取数据,采用回调的方式返回结果
  388. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  389. // 如果 backFunc 返回为 false 则提前结束
  390. // 最大数量和超时时间如果为0的话表示使用默认值
  391. func (h *Hub) GetWithMaxAndTimeout(filter FilterFunc, cmd string, data any, backFunc GetBackFunc, max int, timeout int) (count int) {
  392. // 排除空频道
  393. if filter == nil {
  394. return 0
  395. }
  396. if timeout <= 0 {
  397. timeout = h.cf.WriteWait
  398. }
  399. gd := &GetData{
  400. Filter: filter,
  401. Cmd: cmd,
  402. Data: data,
  403. Max: max,
  404. Timeout: timeout,
  405. backchan: make(chan *ResponseData, 32),
  406. }
  407. sendMax := h.sendRequest(gd)
  408. if sendMax <= 0 {
  409. return 0
  410. }
  411. // 避免出现异常时线程无法退出
  412. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  413. defer func() {
  414. if !timer.Stop() {
  415. select {
  416. case <-timer.C:
  417. default:
  418. }
  419. }
  420. close(gd.backchan)
  421. }()
  422. for {
  423. select {
  424. case rp := <-gd.backchan:
  425. if rp == nil || rp.conn == nil {
  426. // 可能是已经退出了
  427. return
  428. }
  429. ch := rp.conn.channel
  430. if h.cf.PrintMsg {
  431. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  432. }
  433. count++
  434. // 如果这里返回为false这跳出循环
  435. if backFunc != nil && !backFunc(rp) {
  436. return
  437. }
  438. if count >= sendMax {
  439. return
  440. }
  441. case <-timer.C:
  442. return
  443. }
  444. }
  445. // return
  446. }
  447. // 请求频道并获取数据,采用回调的方式返回结果
  448. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  449. // 如果 backFunc 返回为 false 则提前结束
  450. func (h *Hub) Get(filter FilterFunc, cmd string, data any, backFunc GetBackFunc) (count int) {
  451. return h.GetWithMaxAndTimeout(filter, cmd, data, backFunc, 0, 0)
  452. }
  453. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  454. // 如果没有结果将返回 NO_MATCH
  455. func (h *Hub) GetOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  456. h.GetWithMaxAndTimeout(filter, cmd, data, func(rp *ResponseData) (ok bool) {
  457. response = rp
  458. return false
  459. }, 1, 0)
  460. if response == nil {
  461. response = &ResponseData{
  462. State: config.CONNECT_NO_MATCH,
  463. Data: fmt.Appendf(nil, "[%s] %s", config.CONNECT_NO_MATCH_MSG, cmd),
  464. }
  465. }
  466. return
  467. }
  468. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  469. // 如果没有结果将返回 NO_MATCH
  470. func (h *Hub) GetOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  471. h.GetWithMaxAndTimeout(filter, cmd, data, func(rp *ResponseData) (ok bool) {
  472. response = rp
  473. return false
  474. }, 1, timeout)
  475. if response == nil {
  476. response = &ResponseData{
  477. State: config.CONNECT_NO_MATCH,
  478. Data: fmt.Appendf(nil, "[%s] %s", config.CONNECT_NO_MATCH_MSG, cmd),
  479. }
  480. }
  481. return
  482. }
  483. // 推送消息出去,不需要返回数据
  484. func (h *Hub) Push(filter FilterFunc, cmd string, data any) {
  485. h.PushWithMax(filter, cmd, data, 0)
  486. }
  487. // 推送最大对应数量的消息出去,不需要返回数据
  488. func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
  489. // 排除空频道
  490. if filter == nil {
  491. return
  492. }
  493. gd := &GetData{
  494. Filter: filter,
  495. Cmd: cmd,
  496. Data: data,
  497. Max: max,
  498. Timeout: h.cf.WriteWait,
  499. backchan: nil,
  500. }
  501. h.sendRequest(gd)
  502. }
  503. // 增加连接
  504. func (h *Hub) addLine(line *Line) {
  505. if _, ok := h.connects.Load(line); ok {
  506. log.Println("connect have exist")
  507. // 连接已经存在,直接返回
  508. return
  509. }
  510. // 检查是否有相同的channel,如果有的话将其关闭删除
  511. channel := line.channel
  512. h.connects.Range(func(key, _ any) bool {
  513. conn := key.(*Line)
  514. // 删除超时的连接
  515. if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5*int(time.Millisecond)) {
  516. h.connects.Delete(key)
  517. return true
  518. }
  519. if conn.channel == channel {
  520. conn.Close(true)
  521. h.connects.Delete(key)
  522. return false
  523. }
  524. return true
  525. })
  526. h.connects.Store(line, true)
  527. }
  528. // 删除连接
  529. func (h *Hub) removeLine(conn *Line) {
  530. conn.Close(true)
  531. h.connects.Delete(conn)
  532. }
  533. // 获取指定连接的连接持续时间
  534. func (h *Hub) ConnectDuration(conn *Line) time.Duration {
  535. t, ok := h.connects.Load(conn)
  536. if ok {
  537. return time.Since(t.(time.Time))
  538. }
  539. // 如果不存在直接返回0
  540. return time.Duration(0)
  541. }
  542. // 绑定端口,建立服务
  543. // 需要程序运行时调用
  544. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  545. doConnectFunc := func(conn conn.Connect) {
  546. proto, version, channel, auth, err := conn.ReadAuthInfo()
  547. if err != nil {
  548. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  549. conn.Close()
  550. return
  551. }
  552. if version != info.Version || proto != info.Proto {
  553. log.Println("wrong version or protocol: ", version, proto)
  554. conn.Close()
  555. return
  556. }
  557. // 检查验证是否合法
  558. if !h.checkAuthFunc(proto, version, channel, auth) {
  559. conn.Close()
  560. return
  561. }
  562. // 发送频道信息
  563. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  564. log.Println("[WriteAuthInfo ERROR]", err)
  565. conn.Close()
  566. return
  567. }
  568. // 将连接加入现有连接中
  569. done := false
  570. h.connects.Range(func(key, _ any) bool {
  571. line := key.(*Line)
  572. if line.state == Disconnected && line.host == nil && line.ChannelEqualWithoutPrefix(channel) {
  573. line.Start(conn, nil)
  574. done = true
  575. return false
  576. }
  577. return true
  578. })
  579. // 新建一个连接
  580. if !done {
  581. line := NewConnect(h.cf, h, channel, conn, nil)
  582. h.addLine(line)
  583. }
  584. }
  585. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  586. bind := ""
  587. if info.Bind != "" {
  588. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  589. }
  590. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  591. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  592. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  593. }
  594. return errors.New("not connect protocol and version found")
  595. }
  596. // 新建一个连接,不同的连接协议由底层自己选择
  597. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  598. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  599. // 检查当前channel是否已经存在
  600. if !force {
  601. line := h.ChannelToLine(channel)
  602. if line != nil && line.state == Connected {
  603. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  604. return
  605. }
  606. }
  607. if host == nil {
  608. // 获取服务地址等信息
  609. host, err = h.connectHostFunc(channel, Both)
  610. if err != nil {
  611. return err
  612. }
  613. }
  614. var conn conn.Connect
  615. var runProto string
  616. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  617. // 添加定时器
  618. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  619. defer cancel()
  620. taskCh := make(chan bool)
  621. done := false
  622. go func() {
  623. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  624. runProto = ws2.PROTO
  625. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  626. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  627. runProto = tcp2.PROTO
  628. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  629. } else {
  630. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  631. }
  632. if done {
  633. if err != nil {
  634. log.Println("[Dial ERROR]", err)
  635. }
  636. if conn != nil {
  637. conn.Close()
  638. }
  639. } else {
  640. taskCh <- err == nil
  641. }
  642. }()
  643. select {
  644. case ok := <-taskCh:
  645. cancel()
  646. if !ok || err != nil || conn == nil {
  647. log.Println("[Client ERROR]", host.Proto, err)
  648. host.Errors++
  649. host.Updated = time.Now()
  650. if err == nil {
  651. err = errors.New("unknown error")
  652. }
  653. return err
  654. }
  655. case <-ctx.Done():
  656. done = true
  657. return errors.New("timeout")
  658. }
  659. // 发送验证信息
  660. if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
  661. log.Println("[WriteAuthInfo ERROR]", err)
  662. conn.Close()
  663. host.Errors++
  664. host.Updated = time.Now()
  665. return err
  666. }
  667. // 接收频道信息
  668. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  669. if err != nil {
  670. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  671. conn.Close()
  672. host.Errors++
  673. host.Updated = time.Now()
  674. return err
  675. }
  676. // 检查版本和协议是否一致
  677. if version != host.Version || proto != runProto {
  678. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  679. log.Println(err)
  680. conn.Close()
  681. host.Errors++
  682. host.Updated = time.Now()
  683. return err
  684. }
  685. // 检查频道名称是否匹配
  686. if !strings.Contains(channel2, channel) {
  687. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  688. log.Println(err)
  689. conn.Close()
  690. host.Errors++
  691. host.Updated = time.Now()
  692. return err
  693. }
  694. // 检查验证是否合法
  695. if !h.checkAuthFunc(proto, version, channel, auth) {
  696. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  697. log.Println(err)
  698. conn.Close()
  699. host.Errors++
  700. host.Updated = time.Now()
  701. return err
  702. }
  703. // 更新服务主机信息
  704. host.Errors = 0
  705. host.Updated = time.Now()
  706. // 将连接加入现有连接中
  707. done = false
  708. h.connects.Range(func(key, _ any) bool {
  709. line := key.(*Line)
  710. if line.channel == channel {
  711. if line.state == Connected {
  712. if force {
  713. line.Close(true)
  714. } else {
  715. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  716. log.Println(err)
  717. return false
  718. }
  719. }
  720. line.Start(conn, host)
  721. done = true
  722. return false
  723. }
  724. return true
  725. })
  726. if err != nil {
  727. return err
  728. }
  729. // 新建一个连接
  730. if !done {
  731. line := NewConnect(h.cf, h, channel, conn, host)
  732. h.addLine(line)
  733. }
  734. return nil
  735. }
  736. // 重试方式连接服务
  737. // 将会一直阻塞直到连接成功
  738. func (h *Hub) ConnectToServerX(channel string, force bool, host *HostInfo) {
  739. if host == nil {
  740. host, _ = h.connectHostFunc(channel, Direct)
  741. }
  742. for {
  743. err := h.ConnectToServer(channel, force, host)
  744. if err == nil {
  745. return
  746. }
  747. log.Println("[ConnectToServer ERROR, try it again]", channel, host, err)
  748. host = nil
  749. // 产生一个随机数避免刹间重连过载
  750. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  751. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  752. }
  753. }
  754. // 检测处理代理连接
  755. func (h *Hub) checkProxyConnect() {
  756. if h.cf.ProxyTimeout <= 0 {
  757. return
  758. }
  759. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  760. for {
  761. <-proxyTicker.C
  762. now := time.Now().UnixMilli()
  763. h.connects.Range(func(key, _ any) bool {
  764. line := key.(*Line)
  765. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  766. host, err := h.connectHostFunc(line.channel, Direct)
  767. if err != nil {
  768. log.Println("[checkProxyConnect connectHostFunc ERROR]", err)
  769. return false
  770. }
  771. err = h.ConnectToServer(line.channel, true, host)
  772. if err != nil {
  773. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  774. }
  775. }
  776. return true
  777. })
  778. }
  779. }
  780. // 建立一个集线器
  781. // connectFunc 用于监听连接状态的函数,可以为nil
  782. func NewHub(
  783. cf *config.Config,
  784. channel string,
  785. // 客户端需要用的函数
  786. connectHostFunc ConnectHostFunc,
  787. authFunc AuthFunc,
  788. // 服务端需要用的函数
  789. checkAuthFunc CheckAuthFunc,
  790. // 连接状态变化时调用的函数
  791. connectStatusFunc ConnectStatusFunc,
  792. // 验证发送的数据条件是否满足 (可为空)
  793. checkConnectOkFunc CheckConnectOkFunc,
  794. ) (h *Hub) {
  795. h = &Hub{
  796. cf: cf,
  797. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  798. channel: channel,
  799. middle: make([]MiddleFunc, 0),
  800. connectHostFunc: connectHostFunc,
  801. authFunc: authFunc,
  802. checkAuthFunc: checkAuthFunc,
  803. connectStatusFunc: connectStatusFunc,
  804. checkConnectOkFunc: checkConnectOkFunc,
  805. lastCleanDeadConnect: time.Now().UnixMilli(),
  806. }
  807. go h.checkProxyConnect()
  808. return h
  809. }