hub.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. // "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. cf *config.Config
  33. globalID uint16
  34. channel string // 本地频道信息
  35. middle []MiddleFunc // 中间件
  36. // connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  37. lines *Mapx // 记录当前的连接,统一管理
  38. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  39. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  40. // 客户端需要用的函数(服务端可为空)
  41. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  42. // 返回认证信息,发送到对方
  43. authFunc AuthFunc // 获取认证信息,用于发送给对方
  44. // 核对发送过来的认证信息
  45. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  46. // 连接状态变化时调用的函数
  47. connectStatusFunc ConnectStatusFunc
  48. // 验证发送数据的条件是否满足 (可为空)
  49. checkConnectOkFunc CheckConnectOkFunc
  50. // 上次清理异常连接时间戳
  51. lastCleanDeadConnect int64
  52. }
  53. // 转换数据
  54. func (h *Hub) convertData(data any) (reqData []byte, err error) {
  55. switch data := data.(type) {
  56. case []byte:
  57. reqData = data
  58. case string:
  59. reqData = []byte(data)
  60. case func() ([]byte, error):
  61. reqData, err = data()
  62. if err != nil {
  63. log.Println(err.Error())
  64. return nil, err
  65. }
  66. default:
  67. if data != nil {
  68. // 自动转换数据为json格式
  69. reqData, err = json.Marshal(data)
  70. if err != nil {
  71. log.Println(err.Error())
  72. return nil, err
  73. }
  74. }
  75. }
  76. return
  77. }
  78. // 清理异常连接
  79. func (h *Hub) cleanDeadConnect() {
  80. now := time.Now().UnixMilli()
  81. expired := now - int64(h.cf.CleanDeadConnectWait)
  82. if h.lastCleanDeadConnect < expired {
  83. h.lastCleanDeadConnect = now
  84. h.lines.DeleteInvalidLines(expired)
  85. }
  86. }
  87. // 获取通讯消息ID号
  88. func (h *Hub) GetID() uint16 {
  89. h.Lock()
  90. defer h.Unlock()
  91. h.globalID++
  92. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  93. h.globalID = 1
  94. }
  95. for {
  96. // 检查是否在请求队列中存在对应的id
  97. if _, ok := h.msgCache.Load(h.globalID); ok {
  98. h.globalID++
  99. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  100. h.globalID = 1
  101. }
  102. } else {
  103. break
  104. }
  105. }
  106. return h.globalID
  107. }
  108. // 添加中间件
  109. // 如果中间件函数返回为空,表示处理完成,通过
  110. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  111. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  112. h.middle = append(h.middle, middleFunc)
  113. }
  114. // 注册频道,其中频道为正则表达式字符串
  115. func (h *Hub) Subscribe(filter FilterFunc, cmd string, backFunc SubscribeBackFunc) (err error) {
  116. if filter == nil {
  117. return errors.New("filter function can not be nil")
  118. }
  119. reg := &SubscribeData{
  120. Filter: filter,
  121. Cmd: cmd,
  122. BackFunc: backFunc,
  123. }
  124. sub, ok := h.subscribes.Load(cmd)
  125. if ok {
  126. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  127. return
  128. }
  129. regs := make([]*SubscribeData, 1)
  130. regs[0] = reg
  131. h.subscribes.Store(cmd, regs)
  132. return
  133. }
  134. // 遍历频道列表
  135. // 如果 fn 返回 false,则 range 停止迭代
  136. func (h *Hub) ConnectRange(fn func(id int, line *Line) bool) {
  137. h.lines.Range(fn)
  138. }
  139. // 获取当前在线的数量
  140. func (h *Hub) ConnectNum() int {
  141. var count int
  142. h.lines.Range(func(id int, line *Line) bool {
  143. if line.state == Connected {
  144. count++
  145. }
  146. return true
  147. })
  148. return count
  149. }
  150. // 获取所有的在线连接频道
  151. func (h *Hub) AllChannel() []string {
  152. cs := make([]string, 0)
  153. h.lines.Range(func(id int, line *Line) bool {
  154. if line.state == Connected {
  155. cs = append(cs, line.channel)
  156. }
  157. return true
  158. })
  159. return cs
  160. }
  161. // 获取所有连接频道和连接时长
  162. // 为了避免定义数据结构麻烦,采用|隔开, 频道名|连接开始时间
  163. func (h *Hub) AllChannelWithStarted() []string {
  164. cs := make([]string, 0)
  165. h.lines.Range(func(id int, line *Line) bool {
  166. if line.state == Connected {
  167. cs = append(cs, fmt.Sprintf("%s|%d", line.channel, line.started.UnixMilli()))
  168. }
  169. return true
  170. })
  171. return cs
  172. }
  173. // 获取频道并通过函数过滤,如果返回 false 将终止
  174. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  175. h.lines.Range(func(id int, line *Line) bool {
  176. if line.state == Connected {
  177. return fn(line.channel)
  178. }
  179. return true
  180. })
  181. }
  182. // 从 channel 获取连接
  183. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  184. h.lines.Range(func(id int, l *Line) bool {
  185. if l.channel == channel {
  186. line = l
  187. return false
  188. }
  189. return true
  190. })
  191. return
  192. }
  193. // 返回请求结果
  194. func (h *Hub) outResponse(response *ResponseData) {
  195. defer recover() //避免管道已经关闭而引起panic
  196. id := response.Id
  197. t, ok := h.msgCache.Load(id)
  198. if ok {
  199. // 删除数据缓存
  200. h.msgCache.Delete(id)
  201. gm := t.(*GetMsg)
  202. // 停止定时器
  203. if !gm.timer.Stop() {
  204. select {
  205. case <-gm.timer.C:
  206. default:
  207. }
  208. }
  209. // 回应数据到上层
  210. gm.out <- response
  211. }
  212. }
  213. // 发送数据到网络接口
  214. // 返回发送的数量
  215. func (h *Hub) sendRequest(gd *GetData) (count int) {
  216. outData, err := h.convertData(gd.Data)
  217. if err != nil {
  218. log.Println(err)
  219. return 0
  220. }
  221. doit := func(_ int, line *Line) bool {
  222. // 检查连接是否OK
  223. if line.state != Connected {
  224. return true
  225. }
  226. // 验证连接是否达到发送数据的要求
  227. if h.checkConnectOkFunc != nil && !h.checkConnectOkFunc(line, gd) {
  228. return true
  229. }
  230. if gd.Filter(line) {
  231. var id uint16
  232. if gd.backchan != nil {
  233. id = h.GetID()
  234. timeout := gd.Timeout
  235. if timeout <= 0 {
  236. timeout = h.cf.WriteWait
  237. }
  238. fn := func(id uint16, conn *Line) func() {
  239. return func() {
  240. go h.outResponse(&ResponseData{
  241. Id: id,
  242. State: config.GET_TIMEOUT,
  243. Data: fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.GET_TIMEOUT), conn.channel, gd.Cmd),
  244. conn: conn,
  245. })
  246. // 检查是否已经很久时间没有使用连接了
  247. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  248. // 超时关闭当前的连接
  249. log.Println("get message timeout", conn.channel)
  250. // 有可能连接出现问题,断开并重新连接
  251. conn.Close(false)
  252. return
  253. }
  254. }
  255. }(id, line)
  256. // 将要发送的请求缓存
  257. gm := &GetMsg{
  258. out: gd.backchan,
  259. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  260. }
  261. h.msgCache.Store(id, gm)
  262. }
  263. // 组织数据并发送到Connect
  264. line.sendRequest <- &RequestData{
  265. Id: id,
  266. Cmd: gd.Cmd,
  267. Data: outData,
  268. timeout: gd.Timeout,
  269. backchan: gd.backchan,
  270. conn: line,
  271. }
  272. if h.cf.PrintMsg {
  273. log.Println("[SEND]->", id, line.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
  274. }
  275. count++
  276. if gd.Max > 0 && count >= gd.Max {
  277. return false
  278. }
  279. }
  280. return true
  281. }
  282. // 如果没有发送到消息,延时重连直到超时
  283. for i := 0; i <= gd.Timeout; i += 500 {
  284. if gd.Rand {
  285. h.lines.RandRange(doit, i == 0)
  286. } else {
  287. h.lines.Range(doit)
  288. }
  289. if count > 0 {
  290. break
  291. }
  292. time.Sleep(time.Millisecond * 500)
  293. }
  294. return
  295. }
  296. // 执行网络发送过来的命令
  297. func (h *Hub) requestFromNet(request *RequestData) {
  298. cmd := request.Cmd
  299. channel := request.conn.channel
  300. if h.cf.PrintMsg {
  301. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  302. }
  303. // 执行中间件
  304. for _, mdFunc := range h.middle {
  305. rsp := mdFunc(request)
  306. if rsp != nil {
  307. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  308. if rsp.State == config.NEXT_MIDDLE {
  309. continue
  310. }
  311. // 返回消息
  312. if request.Id != 0 {
  313. rsp.Id = request.Id
  314. request.conn.sendResponse <- rsp
  315. }
  316. return
  317. } else {
  318. break
  319. }
  320. }
  321. sub, ok := h.subscribes.Load(cmd)
  322. if ok {
  323. subs := sub.([]*SubscribeData)
  324. // 倒序查找是为了新增的频道响应函数优先执行
  325. for i := len(subs) - 1; i >= 0; i-- {
  326. rg := subs[i]
  327. if rg.Filter(request.conn) {
  328. state, data := rg.BackFunc(request)
  329. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  330. if state == config.NEXT_SUBSCRIBE {
  331. continue
  332. }
  333. var byteData []byte
  334. switch data := data.(type) {
  335. case []byte:
  336. byteData = data
  337. case string:
  338. byteData = []byte(data)
  339. default:
  340. if data != nil {
  341. // 自动转换数据为json格式
  342. var err error
  343. byteData, err = json.Marshal(data)
  344. if err != nil {
  345. log.Println(err.Error())
  346. state = config.CONVERT_FAILED
  347. byteData = fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.CONVERT_FAILED), request.conn.channel, request.Cmd)
  348. }
  349. }
  350. }
  351. // 如果id为0表示不需要回应
  352. if request.Id != 0 {
  353. request.conn.sendResponse <- &ResponseData{
  354. Id: request.Id,
  355. State: state,
  356. Data: byteData,
  357. }
  358. if h.cf.PrintMsg {
  359. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  360. }
  361. }
  362. return
  363. }
  364. }
  365. }
  366. log.Println("[not match command]", channel, cmd)
  367. // 返回没有匹配的消息
  368. request.conn.sendResponse <- &ResponseData{
  369. Id: request.Id,
  370. State: config.NO_MATCH,
  371. Data: fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.NO_MATCH), channel, cmd),
  372. }
  373. }
  374. // 请求频道并获取数据,采用回调的方式返回结果
  375. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  376. // 如果 backFunc 返回为 false 则提前结束
  377. // 最大数量和超时时间如果为0的话表示使用默认值
  378. func (h *Hub) GetWithStruct(gd *GetData, backFunc GetBackFunc) (count int) {
  379. if gd.Filter == nil {
  380. return 0
  381. }
  382. if gd.Timeout <= 0 {
  383. gd.Timeout = h.cf.WriteWait
  384. }
  385. if gd.backchan == nil {
  386. gd.backchan = make(chan *ResponseData, 32)
  387. }
  388. sendMax := h.sendRequest(gd)
  389. if sendMax <= 0 {
  390. return 0
  391. }
  392. // 避免出现异常时线程无法退出
  393. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  394. defer func() {
  395. if !timer.Stop() {
  396. select {
  397. case <-timer.C:
  398. default:
  399. }
  400. }
  401. close(gd.backchan)
  402. }()
  403. for {
  404. select {
  405. case rp := <-gd.backchan:
  406. if rp == nil || rp.conn == nil {
  407. // 可能是已经退出了
  408. return
  409. }
  410. ch := rp.conn.channel
  411. if h.cf.PrintMsg {
  412. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  413. }
  414. count++
  415. // 如果这里返回为false这跳出循环
  416. if backFunc != nil && !backFunc(rp) {
  417. return
  418. }
  419. if count >= sendMax {
  420. return
  421. }
  422. case <-timer.C:
  423. return
  424. }
  425. }
  426. // return
  427. }
  428. // 请求频道并获取数据,采用回调的方式返回结果
  429. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  430. // 如果 backFunc 返回为 false 则提前结束
  431. func (h *Hub) Get(filter FilterFunc, cmd string, data any, backFunc GetBackFunc) (count int) {
  432. return h.GetWithStruct(&GetData{
  433. Filter: filter,
  434. Cmd: cmd,
  435. Data: data,
  436. }, backFunc)
  437. }
  438. // 获取一个数据,阻塞等待到超时间隔
  439. func (h *Hub) GetOneWithStruct(gd *GetData) (response *ResponseData) {
  440. if gd.Filter == nil {
  441. return &ResponseData{
  442. State: config.CONNECT_NO_MATCH,
  443. Data: fmt.Appendf(nil, "[%s] not filter function", config.IdMsg(config.CONNECT_NO_MATCH)),
  444. }
  445. }
  446. gd.Max = 1
  447. h.GetWithStruct(gd, func(rp *ResponseData) (ok bool) {
  448. response = rp
  449. return false
  450. })
  451. if response == nil {
  452. return &ResponseData{
  453. State: config.CONNECT_NO_MATCH,
  454. Data: fmt.Appendf(nil, "[%s] %s", config.IdMsg(config.CONNECT_NO_MATCH), gd.Cmd),
  455. }
  456. }
  457. return
  458. }
  459. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  460. // 如果没有结果将返回 NO_MATCH
  461. func (h *Hub) GetOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  462. return h.GetOneWithStruct(&GetData{
  463. Filter: filter,
  464. Cmd: cmd,
  465. Data: data,
  466. Max: 1,
  467. })
  468. }
  469. func (h *Hub) GetRandOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  470. return h.GetOneWithStruct(&GetData{
  471. Filter: filter,
  472. Cmd: cmd,
  473. Data: data,
  474. Max: 1,
  475. Rand: true,
  476. })
  477. }
  478. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  479. // 如果没有结果将返回 NO_MATCH
  480. func (h *Hub) GetOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  481. return h.GetOneWithStruct(&GetData{
  482. Filter: filter,
  483. Cmd: cmd,
  484. Data: data,
  485. Max: 1,
  486. Timeout: timeout,
  487. })
  488. }
  489. func (h *Hub) GetRandOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  490. return h.GetOneWithStruct(&GetData{
  491. Filter: filter,
  492. Cmd: cmd,
  493. Data: data,
  494. Max: 1,
  495. Timeout: timeout,
  496. Rand: true,
  497. })
  498. }
  499. // 推送消息出去,不需要返回数据
  500. func (h *Hub) Push(filter FilterFunc, cmd string, data any) {
  501. h.PushWithMax(filter, cmd, data, 0)
  502. }
  503. // 推送最大对应数量的消息出去,不需要返回数据
  504. func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
  505. // 排除空频道
  506. if filter == nil {
  507. return
  508. }
  509. gd := &GetData{
  510. Filter: filter,
  511. Cmd: cmd,
  512. Data: data,
  513. Max: max,
  514. Timeout: h.cf.WriteWait,
  515. backchan: nil,
  516. }
  517. h.sendRequest(gd)
  518. }
  519. // 增加连接
  520. func (h *Hub) addLine(line *Line) {
  521. if h.lines.Exist(line) {
  522. log.Println("connect have exist")
  523. // 连接已经存在,直接返回
  524. return
  525. }
  526. h.lines.Store(line)
  527. }
  528. // 删除连接
  529. func (h *Hub) removeLine(line *Line) {
  530. line.Close(true)
  531. h.lines.Delete(line)
  532. }
  533. // 获取指定连接的连接持续时间
  534. func (h *Hub) ConnectDuration(line *Line) time.Duration {
  535. return time.Since(line.started)
  536. }
  537. // 绑定端口,建立服务
  538. // 需要程序运行时调用
  539. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  540. doConnectFunc := func(conn conn.Connect) {
  541. proto, version, channel, auth, err := conn.ReadAuthInfo()
  542. if err != nil {
  543. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  544. conn.Close()
  545. return
  546. }
  547. if version != info.Version || proto != info.Proto {
  548. log.Println("wrong version or protocol: ", version, proto)
  549. conn.Close()
  550. return
  551. }
  552. if !h.checkAuthFunc(proto, version, channel, auth) {
  553. conn.Close()
  554. return
  555. }
  556. // 发送频道信息
  557. if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
  558. log.Println("[WriteAuthInfo ERROR]", err)
  559. conn.Close()
  560. return
  561. }
  562. // 将连接加入现有连接中
  563. done := false
  564. h.lines.Range(func(id int, line *Line) bool {
  565. if line.state == Disconnected && line.host == nil && line.ChannelEqualWithoutPrefix(channel) {
  566. line.Start(conn, nil)
  567. done = true
  568. return false
  569. }
  570. return true
  571. })
  572. // 新建一个连接
  573. if !done {
  574. line := NewConnect(h.cf, h, channel, conn, nil)
  575. h.addLine(line)
  576. }
  577. }
  578. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  579. bind := ""
  580. if info.Bind != "" {
  581. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  582. }
  583. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  584. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  585. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  586. }
  587. return errors.New("not connect protocol and version found")
  588. }
  589. // 新建一个连接,不同的连接协议由底层自己选择
  590. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  591. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  592. // 检查当前channel是否已经存在
  593. if !force {
  594. line := h.ChannelToLine(channel)
  595. if line != nil && line.state == Connected {
  596. err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  597. return
  598. }
  599. }
  600. if host == nil {
  601. if h.connectHostFunc == nil {
  602. return errors.New("not connect host func found")
  603. }
  604. // 获取服务地址等信息
  605. host, err = h.connectHostFunc(channel, Both)
  606. if err != nil {
  607. return err
  608. }
  609. }
  610. var conn conn.Connect
  611. var runProto string
  612. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  613. // 添加定时器
  614. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  615. defer cancel()
  616. taskCh := make(chan bool)
  617. done := false
  618. go func() {
  619. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  620. runProto = ws2.PROTO
  621. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  622. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  623. runProto = tcp2.PROTO
  624. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  625. } else {
  626. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  627. }
  628. if done {
  629. if err != nil {
  630. log.Println("[Dial ERROR]", err)
  631. }
  632. if conn != nil {
  633. conn.Close()
  634. }
  635. } else {
  636. taskCh <- err == nil
  637. }
  638. }()
  639. select {
  640. case ok := <-taskCh:
  641. cancel()
  642. if !ok || err != nil || conn == nil {
  643. log.Println("[Client ERROR]", host.Proto, err)
  644. host.Errors++
  645. host.Updated = time.Now()
  646. if err == nil {
  647. err = errors.New("unknown error")
  648. }
  649. return err
  650. }
  651. case <-ctx.Done():
  652. done = true
  653. return errors.New("timeout")
  654. }
  655. // 发送验证信息
  656. if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
  657. log.Println("[WriteAuthInfo ERROR]", err)
  658. conn.Close()
  659. host.Errors++
  660. host.Updated = time.Now()
  661. return err
  662. }
  663. // 接收频道信息
  664. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  665. if err != nil {
  666. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  667. conn.Close()
  668. host.Errors++
  669. host.Updated = time.Now()
  670. return err
  671. }
  672. // 检查版本和协议是否一致
  673. if version != host.Version || proto != runProto {
  674. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  675. log.Println(err)
  676. conn.Close()
  677. host.Errors++
  678. host.Updated = time.Now()
  679. return err
  680. }
  681. // 检查频道名称是否匹配
  682. if !strings.Contains(channel2, channel) {
  683. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  684. log.Println(err)
  685. conn.Close()
  686. host.Errors++
  687. host.Updated = time.Now()
  688. return err
  689. }
  690. // 检查验证是否合法
  691. if !h.checkAuthFunc(proto, version, channel, auth) {
  692. err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  693. log.Println(err)
  694. conn.Close()
  695. host.Errors++
  696. host.Updated = time.Now()
  697. return err
  698. }
  699. // 更新服务主机信息
  700. host.Errors = 0
  701. host.Updated = time.Now()
  702. // 将连接加入现有连接中
  703. done = false
  704. h.lines.Range(func(id int, line *Line) bool {
  705. if line.channel == channel {
  706. if line.state == Connected {
  707. if force {
  708. line.Close(true)
  709. } else {
  710. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  711. log.Println(err)
  712. return false
  713. }
  714. }
  715. line.Start(conn, host)
  716. done = true
  717. return false
  718. }
  719. return true
  720. })
  721. if err != nil {
  722. return err
  723. }
  724. // 新建一个连接
  725. if !done {
  726. line := NewConnect(h.cf, h, channel, conn, host)
  727. h.addLine(line)
  728. }
  729. return nil
  730. }
  731. // 重试方式连接服务
  732. // 将会一直阻塞直到连接成功
  733. func (h *Hub) ConnectToServerX(channel string, force bool, host *HostInfo) {
  734. if host == nil {
  735. if h.connectHostFunc == nil {
  736. log.Println("ConnectToServerX: not connect host func found")
  737. return
  738. }
  739. host, _ = h.connectHostFunc(channel, Direct)
  740. }
  741. for {
  742. err := h.ConnectToServer(channel, force, host)
  743. if err == nil {
  744. return
  745. }
  746. log.Println("[ConnectToServer ERROR, try it again]", channel, host, err)
  747. host = nil
  748. // 产生一个随机数避免刹间重连过载
  749. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  750. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  751. }
  752. }
  753. // 检测处理连接状态,只在客户端有效
  754. func (h *Hub) checkConnect() {
  755. // 检查客户端获取主机地址函数
  756. if h.connectHostFunc == nil {
  757. return
  758. }
  759. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  760. connectTicker := time.NewTicker(time.Millisecond * time.Duration(h.cf.ConnectCheck))
  761. for {
  762. select {
  763. case <-proxyTicker.C:
  764. now := time.Now().UnixMilli()
  765. h.lines.Range(func(id int, line *Line) bool {
  766. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  767. host, err := h.connectHostFunc(line.channel, Direct)
  768. if err != nil {
  769. log.Println("[proxyTicker connectHostFunc ERROR]", err)
  770. return false
  771. }
  772. err = h.ConnectToServer(line.channel, true, host)
  773. if err != nil {
  774. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  775. }
  776. }
  777. return true
  778. })
  779. case <-connectTicker.C:
  780. h.lines.Range(func(id int, line *Line) bool {
  781. if line.host != nil && line.state == Disconnected {
  782. err := h.ConnectToServer(line.channel, true, nil)
  783. if err != nil {
  784. log.Println("[connectTicker ConnectToServer WARNING]", err)
  785. }
  786. }
  787. return true
  788. })
  789. }
  790. }
  791. }
  792. // 建立一个集线器
  793. // connectFunc 用于监听连接状态的函数,可以为nil
  794. func NewHub(
  795. cf *config.Config,
  796. channel string,
  797. // 客户端需要用的函数 (服务端可空)
  798. connectHostFunc ConnectHostFunc,
  799. // 验证函数,获取认证信息,用于发送给对方
  800. authFunc AuthFunc,
  801. // 核对发送过来的认证信息
  802. checkAuthFunc CheckAuthFunc,
  803. // 连接状态变化时调用的函数
  804. connectStatusFunc ConnectStatusFunc,
  805. // 验证发送数据的条件是否满足 (可为空)
  806. checkConnectOkFunc CheckConnectOkFunc,
  807. ) (h *Hub) {
  808. h = &Hub{
  809. cf: cf,
  810. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  811. channel: channel,
  812. middle: make([]MiddleFunc, 0),
  813. lines: NewMapx(),
  814. connectHostFunc: connectHostFunc,
  815. authFunc: authFunc,
  816. checkAuthFunc: checkAuthFunc,
  817. connectStatusFunc: connectStatusFunc,
  818. checkConnectOkFunc: checkConnectOkFunc,
  819. lastCleanDeadConnect: time.Now().UnixMilli(),
  820. }
  821. go h.checkConnect()
  822. return h
  823. }