hub.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. // "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. cf *config.Config
  33. globalID uint16
  34. channel string // 本地频道信息
  35. middle []MiddleFunc // 中间件
  36. // connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  37. lines *Mapx // 记录当前的连接,统一管理
  38. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  39. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  40. // 客户端需要用的函数
  41. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  42. authFunc AuthFunc // 获取认证信息,用于发送给对方
  43. // 服务端需要用的函数
  44. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  45. // 连接状态变化时调用的函数
  46. connectStatusFunc ConnectStatusFunc
  47. // 验证发送的数据条件是否满足 (可为空)
  48. checkConnectOkFunc CheckConnectOkFunc
  49. // 上次清理异常连接时间戳
  50. lastCleanDeadConnect int64
  51. }
  52. // 转换数据
  53. func (h *Hub) convertData(data any) (reqData []byte, err error) {
  54. switch data := data.(type) {
  55. case []byte:
  56. reqData = data
  57. case string:
  58. reqData = []byte(data)
  59. case func() ([]byte, error):
  60. reqData, err = data()
  61. if err != nil {
  62. log.Println(err.Error())
  63. return nil, err
  64. }
  65. default:
  66. if data != nil {
  67. // 自动转换数据为json格式
  68. reqData, err = json.Marshal(data)
  69. if err != nil {
  70. log.Println(err.Error())
  71. return nil, err
  72. }
  73. }
  74. }
  75. return
  76. }
  77. // 清理异常连接
  78. func (h *Hub) cleanDeadConnect() {
  79. now := time.Now().UnixMilli()
  80. expired := now - int64(h.cf.CleanDeadConnectWait)
  81. if h.lastCleanDeadConnect < expired {
  82. h.lastCleanDeadConnect = now
  83. h.lines.DeleteInvalidLines(expired)
  84. }
  85. }
  86. // 获取通讯消息ID号
  87. func (h *Hub) GetID() uint16 {
  88. h.Lock()
  89. defer h.Unlock()
  90. h.globalID++
  91. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  92. h.globalID = 1
  93. }
  94. for {
  95. // 检查是否在请求队列中存在对应的id
  96. if _, ok := h.msgCache.Load(h.globalID); ok {
  97. h.globalID++
  98. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  99. h.globalID = 1
  100. }
  101. } else {
  102. break
  103. }
  104. }
  105. return h.globalID
  106. }
  107. // 添加中间件
  108. // 如果中间件函数返回为空,表示处理完成,通过
  109. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  110. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  111. h.middle = append(h.middle, middleFunc)
  112. }
  113. // 注册频道,其中频道为正则表达式字符串
  114. func (h *Hub) Subscribe(filter FilterFunc, cmd string, backFunc SubscribeBackFunc) (err error) {
  115. if filter == nil {
  116. return errors.New("filter function can not be nil")
  117. }
  118. reg := &SubscribeData{
  119. Filter: filter,
  120. Cmd: cmd,
  121. BackFunc: backFunc,
  122. }
  123. sub, ok := h.subscribes.Load(cmd)
  124. if ok {
  125. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  126. return
  127. }
  128. regs := make([]*SubscribeData, 1)
  129. regs[0] = reg
  130. h.subscribes.Store(cmd, regs)
  131. return
  132. }
  133. // 遍历频道列表
  134. // 如果 fn 返回 false,则 range 停止迭代
  135. func (h *Hub) ConnectRange(fn func(id int, line *Line) bool) {
  136. h.lines.Range(fn)
  137. }
  138. // 获取当前在线的数量
  139. func (h *Hub) ConnectNum() int {
  140. var count int
  141. h.lines.Range(func(id int, line *Line) bool {
  142. if line.state == Connected {
  143. count++
  144. }
  145. return true
  146. })
  147. return count
  148. }
  149. // 获取所有的在线连接频道
  150. func (h *Hub) AllChannel() []string {
  151. cs := make([]string, 0)
  152. h.lines.Range(func(id int, line *Line) bool {
  153. if line.state == Connected {
  154. cs = append(cs, line.channel)
  155. }
  156. return true
  157. })
  158. return cs
  159. }
  160. // 获取所有连接频道和连接时长
  161. // 为了避免定义数据结构麻烦,采用|隔开, 频道名|连接开始时间
  162. func (h *Hub) AllChannelWithStarted() []string {
  163. cs := make([]string, 0)
  164. h.lines.Range(func(id int, line *Line) bool {
  165. if line.state == Connected {
  166. cs = append(cs, fmt.Sprintf("%s|%d", line.channel, line.started.UnixMilli()))
  167. }
  168. return true
  169. })
  170. return cs
  171. }
  172. // 获取频道并通过函数过滤,如果返回 false 将终止
  173. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  174. h.lines.Range(func(id int, line *Line) bool {
  175. if line.state == Connected {
  176. return fn(line.channel)
  177. }
  178. return true
  179. })
  180. }
  181. // 从 channel 获取连接
  182. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  183. h.lines.Range(func(id int, l *Line) bool {
  184. if l.channel == channel {
  185. line = l
  186. return false
  187. }
  188. return true
  189. })
  190. return
  191. }
  192. // 返回请求结果
  193. func (h *Hub) outResponse(response *ResponseData) {
  194. defer recover() //避免管道已经关闭而引起panic
  195. id := response.Id
  196. t, ok := h.msgCache.Load(id)
  197. if ok {
  198. // 删除数据缓存
  199. h.msgCache.Delete(id)
  200. gm := t.(*GetMsg)
  201. // 停止定时器
  202. if !gm.timer.Stop() {
  203. select {
  204. case <-gm.timer.C:
  205. default:
  206. }
  207. }
  208. // 回应数据到上层
  209. gm.out <- response
  210. }
  211. }
  212. // 发送数据到网络接口
  213. // 返回发送的数量
  214. func (h *Hub) sendRequest(gd *GetData) (count int) {
  215. outData, err := h.convertData(gd.Data)
  216. if err != nil {
  217. log.Println(err)
  218. return 0
  219. }
  220. doit := func(_ int, line *Line) bool {
  221. // 检查连接是否OK
  222. if line.state != Connected {
  223. return true
  224. }
  225. // 验证连接是否达到发送数据的要求
  226. if h.checkConnectOkFunc != nil && !h.checkConnectOkFunc(line, gd) {
  227. return true
  228. }
  229. if gd.Filter(line) {
  230. var id uint16
  231. if gd.backchan != nil {
  232. id = h.GetID()
  233. timeout := gd.Timeout
  234. if timeout <= 0 {
  235. timeout = h.cf.WriteWait
  236. }
  237. fn := func(id uint16, conn *Line) func() {
  238. return func() {
  239. go h.outResponse(&ResponseData{
  240. Id: id,
  241. State: config.GET_TIMEOUT,
  242. Data: fmt.Appendf(nil, "[%s] %s %s", config.GET_TIMEOUT_MSG, conn.channel, gd.Cmd),
  243. conn: conn,
  244. })
  245. // 检查是否已经很久时间没有使用连接了
  246. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  247. // 超时关闭当前的连接
  248. log.Println("get message timeout", conn.channel)
  249. // 有可能连接出现问题,断开并重新连接
  250. conn.Close(false)
  251. return
  252. }
  253. }
  254. }(id, line)
  255. // 将要发送的请求缓存
  256. gm := &GetMsg{
  257. out: gd.backchan,
  258. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  259. }
  260. h.msgCache.Store(id, gm)
  261. }
  262. // 组织数据并发送到Connect
  263. line.sendRequest <- &RequestData{
  264. Id: id,
  265. Cmd: gd.Cmd,
  266. Data: outData,
  267. timeout: gd.Timeout,
  268. backchan: gd.backchan,
  269. conn: line,
  270. }
  271. if h.cf.PrintMsg {
  272. log.Println("[SEND]->", id, line.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
  273. }
  274. count++
  275. if gd.Max > 0 && count >= gd.Max {
  276. return false
  277. }
  278. }
  279. return true
  280. }
  281. // 如果没有发送到消息,延时重连直到超时
  282. for i := 0; i <= gd.Timeout; i += 500 {
  283. if gd.Rand {
  284. h.lines.RandRange(doit, i == 0)
  285. } else {
  286. h.lines.Range(doit)
  287. }
  288. if count > 0 {
  289. break
  290. }
  291. time.Sleep(time.Millisecond * 500)
  292. }
  293. return
  294. }
  295. // 执行网络发送过来的命令
  296. func (h *Hub) requestFromNet(request *RequestData) {
  297. cmd := request.Cmd
  298. channel := request.conn.channel
  299. if h.cf.PrintMsg {
  300. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  301. }
  302. // 执行中间件
  303. for _, mdFunc := range h.middle {
  304. rsp := mdFunc(request)
  305. if rsp != nil {
  306. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  307. if rsp.State == config.NEXT_MIDDLE {
  308. continue
  309. }
  310. // 返回消息
  311. if request.Id != 0 {
  312. rsp.Id = request.Id
  313. request.conn.sendResponse <- rsp
  314. }
  315. return
  316. } else {
  317. break
  318. }
  319. }
  320. sub, ok := h.subscribes.Load(cmd)
  321. if ok {
  322. subs := sub.([]*SubscribeData)
  323. // 倒序查找是为了新增的频道响应函数优先执行
  324. for i := len(subs) - 1; i >= 0; i-- {
  325. rg := subs[i]
  326. if rg.Filter(request.conn) {
  327. state, data := rg.BackFunc(request)
  328. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  329. if state == config.NEXT_SUBSCRIBE {
  330. continue
  331. }
  332. var byteData []byte
  333. switch data := data.(type) {
  334. case []byte:
  335. byteData = data
  336. case string:
  337. byteData = []byte(data)
  338. default:
  339. if data != nil {
  340. // 自动转换数据为json格式
  341. var err error
  342. byteData, err = json.Marshal(data)
  343. if err != nil {
  344. log.Println(err.Error())
  345. state = config.CONVERT_FAILED
  346. byteData = fmt.Appendf(nil, "[%s] %s %s", config.CONVERT_FAILED_MSG, request.conn.channel, request.Cmd)
  347. }
  348. }
  349. }
  350. // 如果id为0表示不需要回应
  351. if request.Id != 0 {
  352. request.conn.sendResponse <- &ResponseData{
  353. Id: request.Id,
  354. State: state,
  355. Data: byteData,
  356. }
  357. if h.cf.PrintMsg {
  358. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  359. }
  360. }
  361. return
  362. }
  363. }
  364. }
  365. log.Println("[not match command]", channel, cmd)
  366. // 返回没有匹配的消息
  367. request.conn.sendResponse <- &ResponseData{
  368. Id: request.Id,
  369. State: config.NO_MATCH,
  370. Data: fmt.Appendf(nil, "[%s] %s %s", config.NO_MATCH_MSG, channel, cmd),
  371. }
  372. }
  373. // 请求频道并获取数据,采用回调的方式返回结果
  374. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  375. // 如果 backFunc 返回为 false 则提前结束
  376. // 最大数量和超时时间如果为0的话表示使用默认值
  377. func (h *Hub) GetWithStruct(gd *GetData, backFunc GetBackFunc) (count int) {
  378. if gd.Filter == nil {
  379. return 0
  380. }
  381. if gd.Timeout <= 0 {
  382. gd.Timeout = h.cf.WriteWait
  383. }
  384. if gd.backchan == nil {
  385. gd.backchan = make(chan *ResponseData, 32)
  386. }
  387. // // 排除空频道
  388. // if filter == nil {
  389. // return 0
  390. // }
  391. // if timeout <= 0 {
  392. // timeout = h.cf.WriteWait
  393. // }
  394. // gd := &GetData{
  395. // Filter: filter,
  396. // Cmd: cmd,
  397. // Data: data,
  398. // Max: max,
  399. // Timeout: timeout,
  400. // backchan: make(chan *ResponseData, 32),
  401. // }
  402. sendMax := h.sendRequest(gd)
  403. if sendMax <= 0 {
  404. return 0
  405. }
  406. // 避免出现异常时线程无法退出
  407. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  408. defer func() {
  409. if !timer.Stop() {
  410. select {
  411. case <-timer.C:
  412. default:
  413. }
  414. }
  415. close(gd.backchan)
  416. }()
  417. for {
  418. select {
  419. case rp := <-gd.backchan:
  420. if rp == nil || rp.conn == nil {
  421. // 可能是已经退出了
  422. return
  423. }
  424. ch := rp.conn.channel
  425. if h.cf.PrintMsg {
  426. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  427. }
  428. count++
  429. // 如果这里返回为false这跳出循环
  430. if backFunc != nil && !backFunc(rp) {
  431. return
  432. }
  433. if count >= sendMax {
  434. return
  435. }
  436. case <-timer.C:
  437. return
  438. }
  439. }
  440. // return
  441. }
  442. // 请求频道并获取数据,采用回调的方式返回结果
  443. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  444. // 如果 backFunc 返回为 false 则提前结束
  445. func (h *Hub) Get(filter FilterFunc, cmd string, data any, backFunc GetBackFunc) (count int) {
  446. return h.GetWithStruct(&GetData{
  447. Filter: filter,
  448. Cmd: cmd,
  449. Data: data,
  450. }, backFunc)
  451. }
  452. // 获取一个数据,阻塞等待到超时间隔
  453. func (h *Hub) GetOneWithStruct(gd *GetData) (response *ResponseData) {
  454. if gd.Filter == nil {
  455. return &ResponseData{
  456. State: config.CONNECT_NO_MATCH,
  457. Data: fmt.Appendf(nil, "[%s] not filter function", config.CONNECT_NO_MATCH_MSG),
  458. }
  459. }
  460. gd.Max = 1
  461. h.GetWithStruct(gd, func(rp *ResponseData) (ok bool) {
  462. response = rp
  463. return false
  464. })
  465. if response == nil {
  466. return &ResponseData{
  467. State: config.CONNECT_NO_MATCH,
  468. Data: fmt.Appendf(nil, "[%s] %s", config.CONNECT_NO_MATCH_MSG, gd.Cmd),
  469. }
  470. }
  471. return
  472. }
  473. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  474. // 如果没有结果将返回 NO_MATCH
  475. func (h *Hub) GetOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  476. return h.GetOneWithStruct(&GetData{
  477. Filter: filter,
  478. Cmd: cmd,
  479. Data: data,
  480. Max: 1,
  481. })
  482. }
  483. func (h *Hub) GetRandOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  484. return h.GetOneWithStruct(&GetData{
  485. Filter: filter,
  486. Cmd: cmd,
  487. Data: data,
  488. Max: 1,
  489. Rand: true,
  490. })
  491. }
  492. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  493. // 如果没有结果将返回 NO_MATCH
  494. func (h *Hub) GetOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  495. return h.GetOneWithStruct(&GetData{
  496. Filter: filter,
  497. Cmd: cmd,
  498. Data: data,
  499. Max: 1,
  500. Timeout: timeout,
  501. })
  502. }
  503. func (h *Hub) GetRandOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  504. return h.GetOneWithStruct(&GetData{
  505. Filter: filter,
  506. Cmd: cmd,
  507. Data: data,
  508. Max: 1,
  509. Timeout: timeout,
  510. Rand: true,
  511. })
  512. }
  513. // 推送消息出去,不需要返回数据
  514. func (h *Hub) Push(filter FilterFunc, cmd string, data any) {
  515. h.PushWithMax(filter, cmd, data, 0)
  516. }
  517. // 推送最大对应数量的消息出去,不需要返回数据
  518. func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
  519. // 排除空频道
  520. if filter == nil {
  521. return
  522. }
  523. gd := &GetData{
  524. Filter: filter,
  525. Cmd: cmd,
  526. Data: data,
  527. Max: max,
  528. Timeout: h.cf.WriteWait,
  529. backchan: nil,
  530. }
  531. h.sendRequest(gd)
  532. }
  533. // 增加连接
  534. func (h *Hub) addLine(line *Line) {
  535. if h.lines.Exist(line) {
  536. log.Println("connect have exist")
  537. // 连接已经存在,直接返回
  538. return
  539. }
  540. h.lines.Store(line)
  541. // // 检查是否有相同的channel,如果有的话将其关闭删除
  542. // channel := line.channel
  543. // h.connects.Range(func(key, _ any) bool {
  544. // conn := key.(*Line)
  545. // // 删除超时的连接
  546. // if conn.state != Connected && conn.host == nil && time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*5*int(time.Millisecond)) {
  547. // h.connects.Delete(key)
  548. // return true
  549. // }
  550. // if conn.channel == channel {
  551. // conn.Close(true)
  552. // h.connects.Delete(key)
  553. // return false
  554. // }
  555. // return true
  556. // })
  557. // h.connects.Store(line, true)
  558. }
  559. // 删除连接
  560. func (h *Hub) removeLine(line *Line) {
  561. line.Close(true)
  562. h.lines.Delete(line)
  563. }
  564. // 获取指定连接的连接持续时间
  565. func (h *Hub) ConnectDuration(line *Line) time.Duration {
  566. return time.Since(line.started)
  567. }
  568. // 绑定端口,建立服务
  569. // 需要程序运行时调用
  570. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  571. doConnectFunc := func(conn conn.Connect) {
  572. proto, version, channel, auth, err := conn.ReadAuthInfo()
  573. if err != nil {
  574. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  575. conn.Close()
  576. return
  577. }
  578. if version != info.Version || proto != info.Proto {
  579. log.Println("wrong version or protocol: ", version, proto)
  580. conn.Close()
  581. return
  582. }
  583. // 检查验证是否合法
  584. if !h.checkAuthFunc(proto, version, channel, auth) {
  585. conn.Close()
  586. return
  587. }
  588. // 发送频道信息
  589. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  590. log.Println("[WriteAuthInfo ERROR]", err)
  591. conn.Close()
  592. return
  593. }
  594. // 将连接加入现有连接中
  595. done := false
  596. h.lines.Range(func(id int, line *Line) bool {
  597. if line.state == Disconnected && line.host == nil && line.ChannelEqualWithoutPrefix(channel) {
  598. line.Start(conn, nil)
  599. done = true
  600. return false
  601. }
  602. return true
  603. })
  604. // 新建一个连接
  605. if !done {
  606. line := NewConnect(h.cf, h, channel, conn, nil)
  607. h.addLine(line)
  608. }
  609. }
  610. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  611. bind := ""
  612. if info.Bind != "" {
  613. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  614. }
  615. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  616. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  617. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  618. }
  619. return errors.New("not connect protocol and version found")
  620. }
  621. // 新建一个连接,不同的连接协议由底层自己选择
  622. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  623. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  624. // 检查当前channel是否已经存在
  625. if !force {
  626. line := h.ChannelToLine(channel)
  627. if line != nil && line.state == Connected {
  628. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  629. return
  630. }
  631. }
  632. if host == nil {
  633. // 获取服务地址等信息
  634. host, err = h.connectHostFunc(channel, Both)
  635. if err != nil {
  636. return err
  637. }
  638. }
  639. var conn conn.Connect
  640. var runProto string
  641. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  642. // 添加定时器
  643. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  644. defer cancel()
  645. taskCh := make(chan bool)
  646. done := false
  647. go func() {
  648. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  649. runProto = ws2.PROTO
  650. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  651. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  652. runProto = tcp2.PROTO
  653. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  654. } else {
  655. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  656. }
  657. if done {
  658. if err != nil {
  659. log.Println("[Dial ERROR]", err)
  660. }
  661. if conn != nil {
  662. conn.Close()
  663. }
  664. } else {
  665. taskCh <- err == nil
  666. }
  667. }()
  668. select {
  669. case ok := <-taskCh:
  670. cancel()
  671. if !ok || err != nil || conn == nil {
  672. log.Println("[Client ERROR]", host.Proto, err)
  673. host.Errors++
  674. host.Updated = time.Now()
  675. if err == nil {
  676. err = errors.New("unknown error")
  677. }
  678. return err
  679. }
  680. case <-ctx.Done():
  681. done = true
  682. return errors.New("timeout")
  683. }
  684. // 发送验证信息
  685. if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
  686. log.Println("[WriteAuthInfo ERROR]", err)
  687. conn.Close()
  688. host.Errors++
  689. host.Updated = time.Now()
  690. return err
  691. }
  692. // 接收频道信息
  693. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  694. if err != nil {
  695. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  696. conn.Close()
  697. host.Errors++
  698. host.Updated = time.Now()
  699. return err
  700. }
  701. // 检查版本和协议是否一致
  702. if version != host.Version || proto != runProto {
  703. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  704. log.Println(err)
  705. conn.Close()
  706. host.Errors++
  707. host.Updated = time.Now()
  708. return err
  709. }
  710. // 检查频道名称是否匹配
  711. if !strings.Contains(channel2, channel) {
  712. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  713. log.Println(err)
  714. conn.Close()
  715. host.Errors++
  716. host.Updated = time.Now()
  717. return err
  718. }
  719. // 检查验证是否合法
  720. if !h.checkAuthFunc(proto, version, channel, auth) {
  721. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  722. log.Println(err)
  723. conn.Close()
  724. host.Errors++
  725. host.Updated = time.Now()
  726. return err
  727. }
  728. // 更新服务主机信息
  729. host.Errors = 0
  730. host.Updated = time.Now()
  731. // 将连接加入现有连接中
  732. done = false
  733. h.lines.Range(func(id int, line *Line) bool {
  734. if line.channel == channel {
  735. if line.state == Connected {
  736. if force {
  737. line.Close(true)
  738. } else {
  739. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  740. log.Println(err)
  741. return false
  742. }
  743. }
  744. line.Start(conn, host)
  745. done = true
  746. return false
  747. }
  748. return true
  749. })
  750. if err != nil {
  751. return err
  752. }
  753. // 新建一个连接
  754. if !done {
  755. line := NewConnect(h.cf, h, channel, conn, host)
  756. h.addLine(line)
  757. }
  758. return nil
  759. }
  760. // 重试方式连接服务
  761. // 将会一直阻塞直到连接成功
  762. func (h *Hub) ConnectToServerX(channel string, force bool, host *HostInfo) {
  763. if host == nil {
  764. host, _ = h.connectHostFunc(channel, Direct)
  765. }
  766. for {
  767. err := h.ConnectToServer(channel, force, host)
  768. if err == nil {
  769. return
  770. }
  771. log.Println("[ConnectToServer ERROR, try it again]", channel, host, err)
  772. host = nil
  773. // 产生一个随机数避免刹间重连过载
  774. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  775. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  776. }
  777. }
  778. // 检测处理代理连接
  779. func (h *Hub) checkProxyConnect() {
  780. if h.cf.ProxyTimeout <= 0 {
  781. return
  782. }
  783. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  784. for {
  785. <-proxyTicker.C
  786. now := time.Now().UnixMilli()
  787. h.lines.Range(func(id int, line *Line) bool {
  788. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  789. host, err := h.connectHostFunc(line.channel, Direct)
  790. if err != nil {
  791. log.Println("[checkProxyConnect connectHostFunc ERROR]", err)
  792. return false
  793. }
  794. err = h.ConnectToServer(line.channel, true, host)
  795. if err != nil {
  796. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  797. }
  798. }
  799. return true
  800. })
  801. }
  802. }
  803. // 建立一个集线器
  804. // connectFunc 用于监听连接状态的函数,可以为nil
  805. func NewHub(
  806. cf *config.Config,
  807. channel string,
  808. // 客户端需要用的函数
  809. connectHostFunc ConnectHostFunc,
  810. authFunc AuthFunc,
  811. // 服务端需要用的函数
  812. checkAuthFunc CheckAuthFunc,
  813. // 连接状态变化时调用的函数
  814. connectStatusFunc ConnectStatusFunc,
  815. // 验证发送的数据条件是否满足 (可为空)
  816. checkConnectOkFunc CheckConnectOkFunc,
  817. ) (h *Hub) {
  818. h = &Hub{
  819. cf: cf,
  820. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  821. channel: channel,
  822. middle: make([]MiddleFunc, 0),
  823. lines: NewMapx(),
  824. connectHostFunc: connectHostFunc,
  825. authFunc: authFunc,
  826. checkAuthFunc: checkAuthFunc,
  827. connectStatusFunc: connectStatusFunc,
  828. checkConnectOkFunc: checkConnectOkFunc,
  829. lastCleanDeadConnect: time.Now().UnixMilli(),
  830. }
  831. go h.checkProxyConnect()
  832. return h
  833. }