hub.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. // "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. cf *config.Config
  33. globalID uint16
  34. channel string // 本地频道信息
  35. middle []MiddleFunc // 中间件
  36. // connects sync.Map // map[*Line]bool(true) //记录当前的连接,方便查找
  37. lines *Mapx // 记录当前的连接,统一管理
  38. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  39. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  40. // 客户端需要用的函数(服务端可为空)
  41. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  42. // 返回认证信息,发送到对方
  43. authFunc AuthFunc // 获取认证信息,用于发送给对方
  44. // 核对发送过来的认证信息
  45. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  46. // 连接状态变化时调用的函数
  47. connectStatusFunc ConnectStatusFunc
  48. // 验证发送数据的条件是否满足 (可空)
  49. checkConnectOkFunc CheckConnectOkFunc
  50. // 通过过滤函数获取一个频道信息 (可空)
  51. filterToChannelFunc FilterToChannelFunc
  52. // 上次清理异常连接时间戳
  53. lastCleanDeadConnect int64
  54. }
  55. func (h *Hub) SetCheckConnectOkFunc(fn CheckConnectOkFunc) {
  56. h.checkConnectOkFunc = fn
  57. }
  58. func (h *Hub) SetFilterToChannelFunc(fn FilterToChannelFunc) {
  59. h.filterToChannelFunc = fn
  60. }
  61. // 转换数据
  62. func (h *Hub) convertData(data any) (reqData []byte, err error) {
  63. switch data := data.(type) {
  64. case []byte:
  65. reqData = data
  66. case string:
  67. reqData = []byte(data)
  68. case func() ([]byte, error):
  69. reqData, err = data()
  70. if err != nil {
  71. log.Println(err.Error())
  72. return nil, err
  73. }
  74. default:
  75. if data != nil {
  76. // 自动转换数据为json格式
  77. reqData, err = json.Marshal(data)
  78. if err != nil {
  79. log.Println(err.Error())
  80. return nil, err
  81. }
  82. }
  83. }
  84. return
  85. }
  86. // 清理异常连接
  87. func (h *Hub) cleanDeadConnect() {
  88. now := time.Now().UnixMilli()
  89. expired := now - int64(h.cf.CleanDeadConnectWait)
  90. if h.lastCleanDeadConnect < expired {
  91. h.lastCleanDeadConnect = now
  92. h.lines.DeleteInvalidLines(expired)
  93. }
  94. }
  95. // 获取通讯消息ID号
  96. func (h *Hub) GetID() uint16 {
  97. h.Lock()
  98. defer h.Unlock()
  99. h.globalID++
  100. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  101. h.globalID = 1
  102. }
  103. for {
  104. // 检查是否在请求队列中存在对应的id
  105. if _, ok := h.msgCache.Load(h.globalID); ok {
  106. h.globalID++
  107. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  108. h.globalID = 1
  109. }
  110. } else {
  111. break
  112. }
  113. }
  114. return h.globalID
  115. }
  116. // 添加中间件
  117. // 如果中间件函数返回为空,表示处理完成,通过
  118. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  119. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  120. h.middle = append(h.middle, middleFunc)
  121. }
  122. // 注册频道,其中频道为正则表达式字符串
  123. func (h *Hub) Subscribe(filter FilterFunc, cmd string, backFunc SubscribeBackFunc) (err error) {
  124. if filter == nil {
  125. return errors.New("filter function can not be nil")
  126. }
  127. reg := &SubscribeData{
  128. Filter: filter,
  129. Cmd: cmd,
  130. BackFunc: backFunc,
  131. }
  132. sub, ok := h.subscribes.Load(cmd)
  133. if ok {
  134. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  135. return
  136. }
  137. regs := make([]*SubscribeData, 1)
  138. regs[0] = reg
  139. h.subscribes.Store(cmd, regs)
  140. return
  141. }
  142. // 遍历频道列表
  143. // 如果 fn 返回 false,则 range 停止迭代
  144. func (h *Hub) ConnectRange(fn func(id int, line *Line) bool) {
  145. h.lines.Range(fn)
  146. }
  147. // 获取当前在线的数量
  148. func (h *Hub) ConnectNum() int {
  149. var count int
  150. h.lines.Range(func(id int, line *Line) bool {
  151. if line.state == Connected {
  152. count++
  153. }
  154. return true
  155. })
  156. return count
  157. }
  158. // 获取所有的在线连接频道
  159. func (h *Hub) AllChannel() []string {
  160. cs := make([]string, 0)
  161. h.lines.Range(func(id int, line *Line) bool {
  162. if line.state == Connected {
  163. cs = append(cs, line.channel)
  164. }
  165. return true
  166. })
  167. return cs
  168. }
  169. // 获取所有连接频道和连接时长
  170. // 为了避免定义数据结构麻烦,采用|隔开, 频道名|连接开始时间
  171. func (h *Hub) AllChannelWithStarted() []string {
  172. cs := make([]string, 0)
  173. h.lines.Range(func(id int, line *Line) bool {
  174. if line.state == Connected {
  175. cs = append(cs, fmt.Sprintf("%s|%d", line.channel, line.updated.UnixMilli()))
  176. }
  177. return true
  178. })
  179. return cs
  180. }
  181. // 获取频道并通过函数过滤,如果返回 false 将终止
  182. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  183. h.lines.Range(func(id int, line *Line) bool {
  184. if line.state == Connected {
  185. return fn(line.channel)
  186. }
  187. return true
  188. })
  189. }
  190. // 从 channel 获取连接
  191. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  192. h.lines.Range(func(id int, l *Line) bool {
  193. if l.IsChannelEqual(channel) {
  194. line = l
  195. return false
  196. }
  197. return true
  198. })
  199. return
  200. }
  201. // 返回请求结果
  202. func (h *Hub) outResponse(response *ResponseData) {
  203. defer recover() //避免管道已经关闭而引起panic
  204. id := response.Id
  205. t, ok := h.msgCache.Load(id)
  206. if ok {
  207. // 删除数据缓存
  208. h.msgCache.Delete(id)
  209. gm := t.(*GetMsg)
  210. // 停止定时器
  211. if !gm.timer.Stop() {
  212. select {
  213. case <-gm.timer.C:
  214. default:
  215. }
  216. }
  217. // 回应数据到上层
  218. gm.out <- response
  219. }
  220. }
  221. // 发送数据到网络接口
  222. // 返回发送的数量
  223. func (h *Hub) sendRequest(gd *GetData) (count int) {
  224. outData, err := h.convertData(gd.Data)
  225. if err != nil {
  226. log.Println(err)
  227. return 0
  228. }
  229. doit := func(_ int, line *Line) bool {
  230. // 检查连接是否OK
  231. if line.state != Connected {
  232. return true
  233. }
  234. // 验证连接是否达到发送数据的要求
  235. if h.checkConnectOkFunc != nil && !h.checkConnectOkFunc(line, gd) {
  236. return true
  237. }
  238. if gd.Filter(line) {
  239. var id uint16
  240. if gd.backchan != nil {
  241. id = h.GetID()
  242. timeout := gd.Timeout
  243. if timeout <= 0 {
  244. timeout = h.cf.WriteWait
  245. }
  246. fn := func(id uint16, conn *Line) func() {
  247. return func() {
  248. go h.outResponse(&ResponseData{
  249. Id: id,
  250. State: config.GET_TIMEOUT,
  251. Data: fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.GET_TIMEOUT), conn.channel, gd.Cmd),
  252. conn: conn,
  253. })
  254. // 检查是否已经很久时间没有使用连接了
  255. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  256. // 超时关闭当前的连接
  257. log.Println("get message timeout", conn.channel)
  258. // 有可能连接出现问题,断开并重新连接
  259. conn.Close(false)
  260. return
  261. }
  262. }
  263. }(id, line)
  264. // 将要发送的请求缓存
  265. gm := &GetMsg{
  266. out: gd.backchan,
  267. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  268. }
  269. h.msgCache.Store(id, gm)
  270. }
  271. // 组织数据并发送到Connect
  272. line.sendRequest <- &RequestData{
  273. Id: id,
  274. Cmd: gd.Cmd,
  275. Data: outData,
  276. timeout: gd.Timeout,
  277. backchan: gd.backchan,
  278. conn: line,
  279. }
  280. if h.cf.PrintMsg {
  281. log.Println("[SEND]->", id, line.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
  282. }
  283. count++
  284. if gd.Max > 0 && count >= gd.Max {
  285. return false
  286. }
  287. }
  288. return true
  289. }
  290. // 如果没有发送到消息,延时重连直到超时
  291. for i := 0; i <= gd.Timeout; i += 500 {
  292. if gd.Rand {
  293. h.lines.RandRange(doit, i == 0)
  294. } else {
  295. h.lines.Range(doit)
  296. }
  297. if count > 0 {
  298. break
  299. }
  300. // 如果是客户端,并且有机会自动连接,则尝试自动连接
  301. if i == 0 && h.connectHostFunc != nil && h.filterToChannelFunc != nil {
  302. channel := h.filterToChannelFunc(gd.Filter)
  303. err := h.ConnectToServer(channel, false, nil)
  304. if err != nil {
  305. log.Println(err)
  306. return 0
  307. }
  308. } else {
  309. time.Sleep(time.Millisecond * 500)
  310. }
  311. }
  312. return
  313. }
  314. // 执行网络发送过来的命令
  315. func (h *Hub) requestFromNet(request *RequestData) {
  316. cmd := request.Cmd
  317. channel := request.conn.channel
  318. if h.cf.PrintMsg {
  319. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  320. }
  321. // 执行中间件
  322. for _, mdFunc := range h.middle {
  323. rsp := mdFunc(request)
  324. if rsp != nil {
  325. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  326. if rsp.State == config.NEXT_MIDDLE {
  327. continue
  328. }
  329. // 返回消息
  330. if request.Id != 0 {
  331. rsp.Id = request.Id
  332. request.conn.sendResponse <- rsp
  333. }
  334. return
  335. } else {
  336. break
  337. }
  338. }
  339. sub, ok := h.subscribes.Load(cmd)
  340. if ok {
  341. subs := sub.([]*SubscribeData)
  342. // 倒序查找是为了新增的频道响应函数优先执行
  343. for i := len(subs) - 1; i >= 0; i-- {
  344. rg := subs[i]
  345. if rg.Filter(request.conn) {
  346. state, data := rg.BackFunc(request)
  347. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  348. if state == config.NEXT_SUBSCRIBE {
  349. continue
  350. }
  351. var byteData []byte
  352. switch data := data.(type) {
  353. case []byte:
  354. byteData = data
  355. case string:
  356. byteData = []byte(data)
  357. default:
  358. if data != nil {
  359. // 自动转换数据为json格式
  360. var err error
  361. byteData, err = json.Marshal(data)
  362. if err != nil {
  363. log.Println(err.Error())
  364. state = config.CONVERT_FAILED
  365. byteData = fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.CONVERT_FAILED), request.conn.channel, request.Cmd)
  366. }
  367. }
  368. }
  369. // 如果id为0表示不需要回应
  370. if request.Id != 0 {
  371. request.conn.sendResponse <- &ResponseData{
  372. Id: request.Id,
  373. State: state,
  374. Data: byteData,
  375. }
  376. if h.cf.PrintMsg {
  377. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  378. }
  379. }
  380. return
  381. }
  382. }
  383. }
  384. log.Println("[not match command]", channel, cmd)
  385. // 返回没有匹配的消息
  386. request.conn.sendResponse <- &ResponseData{
  387. Id: request.Id,
  388. State: config.NO_MATCH,
  389. Data: fmt.Appendf(nil, "[%s] %s %s", config.IdMsg(config.NO_MATCH), channel, cmd),
  390. }
  391. }
  392. // 请求频道并获取数据,采用回调的方式返回结果
  393. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  394. // 如果 backFunc 返回为 false 则提前结束
  395. // 最大数量和超时时间如果为0的话表示使用默认值
  396. func (h *Hub) GetWithStruct(gd *GetData, backFunc GetBackFunc) (count int) {
  397. if gd.Filter == nil {
  398. return 0
  399. }
  400. if gd.Timeout <= 0 {
  401. gd.Timeout = h.cf.WriteWait
  402. }
  403. if gd.backchan == nil {
  404. gd.backchan = make(chan *ResponseData, 32)
  405. }
  406. sendMax := h.sendRequest(gd)
  407. if sendMax <= 0 {
  408. return 0
  409. }
  410. // 避免出现异常时线程无法退出
  411. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  412. defer func() {
  413. if !timer.Stop() {
  414. select {
  415. case <-timer.C:
  416. default:
  417. }
  418. }
  419. close(gd.backchan)
  420. }()
  421. for {
  422. select {
  423. case rp := <-gd.backchan:
  424. if rp == nil || rp.conn == nil {
  425. // 可能是已经退出了
  426. return
  427. }
  428. ch := rp.conn.channel
  429. if h.cf.PrintMsg {
  430. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  431. }
  432. count++
  433. // 如果这里返回为false这跳出循环
  434. if backFunc != nil && !backFunc(rp) {
  435. return
  436. }
  437. if count >= sendMax {
  438. return
  439. }
  440. case <-timer.C:
  441. return
  442. }
  443. }
  444. }
  445. // 请求频道并获取数据,采用回调的方式返回结果
  446. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  447. // 如果 backFunc 返回为 false 则提前结束
  448. func (h *Hub) Get(filter FilterFunc, cmd string, data any, backFunc GetBackFunc) (count int) {
  449. return h.GetWithStruct(&GetData{
  450. Filter: filter,
  451. Cmd: cmd,
  452. Data: data,
  453. }, backFunc)
  454. }
  455. // 获取一个数据,阻塞等待到超时间隔
  456. func (h *Hub) GetOneWithStruct(gd *GetData) (response *ResponseData) {
  457. if gd.Filter == nil {
  458. return &ResponseData{
  459. State: config.CONNECT_NO_MATCH,
  460. Data: fmt.Appendf(nil, "[%s] not filter function", config.IdMsg(config.CONNECT_NO_MATCH)),
  461. }
  462. }
  463. gd.Max = 1
  464. h.GetWithStruct(gd, func(rp *ResponseData) (ok bool) {
  465. response = rp
  466. return false
  467. })
  468. if response == nil {
  469. return &ResponseData{
  470. State: config.CONNECT_NO_MATCH,
  471. Data: fmt.Appendf(nil, "[%s] %s", config.IdMsg(config.CONNECT_NO_MATCH), gd.Cmd),
  472. }
  473. }
  474. return
  475. }
  476. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  477. // 如果没有结果将返回 NO_MATCH
  478. func (h *Hub) GetOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  479. return h.GetOneWithStruct(&GetData{
  480. Filter: filter,
  481. Cmd: cmd,
  482. Data: data,
  483. Max: 1,
  484. })
  485. }
  486. func (h *Hub) GetRandOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  487. return h.GetOneWithStruct(&GetData{
  488. Filter: filter,
  489. Cmd: cmd,
  490. Data: data,
  491. Max: 1,
  492. Rand: true,
  493. })
  494. }
  495. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  496. // 如果没有结果将返回 NO_MATCH
  497. func (h *Hub) GetOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  498. return h.GetOneWithStruct(&GetData{
  499. Filter: filter,
  500. Cmd: cmd,
  501. Data: data,
  502. Max: 1,
  503. Timeout: timeout,
  504. })
  505. }
  506. func (h *Hub) GetRandOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  507. return h.GetOneWithStruct(&GetData{
  508. Filter: filter,
  509. Cmd: cmd,
  510. Data: data,
  511. Max: 1,
  512. Timeout: timeout,
  513. Rand: true,
  514. })
  515. }
  516. // 推送消息出去,不需要返回数据
  517. func (h *Hub) Push(filter FilterFunc, cmd string, data any) {
  518. h.PushWithMax(filter, cmd, data, 0)
  519. }
  520. // 推送最大对应数量的消息出去,不需要返回数据
  521. func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
  522. // 排除空频道
  523. if filter == nil {
  524. return
  525. }
  526. gd := &GetData{
  527. Filter: filter,
  528. Cmd: cmd,
  529. Data: data,
  530. Max: max,
  531. Timeout: h.cf.WriteWait,
  532. backchan: nil,
  533. }
  534. h.sendRequest(gd)
  535. }
  536. // 增加连接
  537. func (h *Hub) addLine(line *Line) {
  538. if h.lines.Exist(line) {
  539. log.Println("connect have exist")
  540. // 连接已经存在,直接返回
  541. return
  542. }
  543. h.lines.Store(line)
  544. }
  545. // 删除连接
  546. func (h *Hub) removeLine(line *Line) {
  547. line.Close(true)
  548. h.lines.Delete(line)
  549. }
  550. // 获取指定连接的连接持续时间
  551. func (h *Hub) ConnectDuration(line *Line) time.Duration {
  552. return time.Since(line.started)
  553. }
  554. // 绑定端口,建立服务
  555. // 需要程序运行时调用
  556. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  557. doConnectFunc := func(conn conn.Connect) {
  558. proto, version, channel, auth, err := conn.ReadAuthInfo()
  559. if err != nil {
  560. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  561. conn.Close()
  562. return
  563. }
  564. if version != info.Version || proto != info.Proto {
  565. log.Println("wrong version or protocol: ", version, proto)
  566. conn.Close()
  567. return
  568. }
  569. if !h.checkAuthFunc(false, proto, version, channel, auth) {
  570. err = fmt.Errorf("[server checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  571. log.Println(err)
  572. conn.Close()
  573. return
  574. }
  575. // 发送频道信息
  576. if err := conn.WriteAuthInfo(h.channel, h.authFunc(false, proto, version, channel, auth)); err != nil {
  577. log.Println("[WriteAuthInfo ERROR]", err)
  578. conn.Close()
  579. return
  580. }
  581. // 将连接加入现有连接中
  582. done := false
  583. h.lines.Range(func(id int, line *Line) bool {
  584. if line.state == Disconnected && line.host == nil && line.IsChannelEqual(channel) {
  585. line.Start(channel, conn, nil)
  586. done = true
  587. return false
  588. }
  589. return true
  590. })
  591. // 新建一个连接
  592. if !done {
  593. line := NewConnect(h.cf, h, channel, conn, nil)
  594. h.addLine(line)
  595. }
  596. }
  597. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  598. bind := ""
  599. if info.Bind != "" {
  600. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  601. }
  602. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  603. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  604. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  605. }
  606. return errors.New("not connect protocol and version found")
  607. }
  608. // 新建一个连接,不同的连接协议由底层自己选择
  609. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  610. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err error) {
  611. // 检查当前channel是否已经存在
  612. if !force {
  613. line := h.ChannelToLine(channel)
  614. if line != nil && line.state == Connected {
  615. // err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  616. log.Println("[ConnectToServer] channel existed:", channel)
  617. return
  618. }
  619. }
  620. if host == nil {
  621. if h.connectHostFunc == nil {
  622. return errors.New("not connect host func found")
  623. }
  624. // 获取服务地址等信息
  625. host, err = h.connectHostFunc(channel, Both)
  626. if err != nil {
  627. return err
  628. }
  629. }
  630. var conn conn.Connect
  631. var runProto string
  632. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  633. // 添加定时器
  634. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  635. defer cancel()
  636. taskCh := make(chan bool)
  637. done := false
  638. go func() {
  639. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  640. runProto = ws2.PROTO
  641. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  642. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  643. runProto = tcp2.PROTO
  644. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  645. } else {
  646. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  647. }
  648. if done {
  649. if err != nil {
  650. log.Println("[Dial ERROR]", err)
  651. }
  652. if conn != nil {
  653. conn.Close()
  654. }
  655. } else {
  656. taskCh <- err == nil
  657. }
  658. }()
  659. select {
  660. case ok := <-taskCh:
  661. cancel()
  662. if !ok || err != nil || conn == nil {
  663. log.Println("[Client ERROR]", host.Proto, err)
  664. host.Errors++
  665. host.Updated = time.Now()
  666. if err == nil {
  667. err = errors.New("unknown error")
  668. }
  669. return err
  670. }
  671. case <-ctx.Done():
  672. done = true
  673. return errors.New("timeout")
  674. }
  675. // 如果 host 是代理,将代理信息添加到channel中
  676. localChannel := h.channel
  677. if host.Proxy {
  678. localChannel = localChannel + "?proxy=" + host.Host
  679. }
  680. // 发送验证信息
  681. if err := conn.WriteAuthInfo(localChannel, h.authFunc(true, runProto, host.Version, channel, nil)); err != nil {
  682. log.Println("[WriteAuthInfo ERROR]", err)
  683. conn.Close()
  684. host.Errors++
  685. host.Updated = time.Now()
  686. return err
  687. }
  688. // 接收频道信息
  689. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  690. if err != nil {
  691. log.Println("[ConnectToServer ReadAuthInfo ERROR]", err)
  692. conn.Close()
  693. host.Errors++
  694. host.Updated = time.Now()
  695. return err
  696. }
  697. // 检查版本和协议是否一致
  698. if version != host.Version || proto != runProto {
  699. err = fmt.Errorf("[version or protocol wrong ERROR] %d, %s", version, proto)
  700. log.Println(err)
  701. conn.Close()
  702. host.Errors++
  703. host.Updated = time.Now()
  704. return err
  705. }
  706. // 检查频道名称是否匹配
  707. if !strings.Contains(channel2, channel) {
  708. err = fmt.Errorf("[channel ERROR] want %s, get %s", channel, channel2)
  709. log.Println(err)
  710. conn.Close()
  711. host.Errors++
  712. host.Updated = time.Now()
  713. return err
  714. }
  715. // 检查验证是否合法
  716. if !h.checkAuthFunc(true, proto, version, channel, auth) {
  717. err = fmt.Errorf("[client checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  718. log.Println(err)
  719. conn.Close()
  720. host.Errors++
  721. host.Updated = time.Now()
  722. return err
  723. }
  724. // 更新服务主机信息
  725. host.Errors = 0
  726. host.Updated = time.Now()
  727. // 将连接加入现有连接中
  728. done = false
  729. h.lines.Range(func(id int, line *Line) bool {
  730. if line.channel == channel {
  731. if line.state == Connected {
  732. if force {
  733. line.Close(true)
  734. } else {
  735. err = fmt.Errorf("[connectToServer ERROR] channel already connected: %s", channel)
  736. log.Println(err)
  737. return false
  738. }
  739. }
  740. line.Start(channel, conn, host)
  741. done = true
  742. return false
  743. }
  744. return true
  745. })
  746. if err != nil {
  747. return err
  748. }
  749. // 新建一个连接
  750. if !done {
  751. line := NewConnect(h.cf, h, channel, conn, host)
  752. h.addLine(line)
  753. }
  754. return nil
  755. }
  756. // 重试方式连接服务
  757. // 将会一直阻塞直到连接成功
  758. func (h *Hub) ConnectToServerX(channel string, force bool, host *HostInfo) {
  759. if host == nil {
  760. if h.connectHostFunc == nil {
  761. log.Println("ConnectToServerX: not connect host func found")
  762. return
  763. }
  764. host, _ = h.connectHostFunc(channel, Direct)
  765. }
  766. for {
  767. err := h.ConnectToServer(channel, force, host)
  768. if err == nil {
  769. return
  770. }
  771. log.Println("[ConnectToServer ERROR, try it again]", channel, host, err)
  772. host = nil
  773. // 产生一个随机数避免刹间重连过载
  774. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  775. time.Sleep(time.Duration(r.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond)
  776. }
  777. }
  778. // 检测处理连接状态,只在客户端有效
  779. func (h *Hub) checkConnect() {
  780. // 检查客户端获取主机地址函数
  781. if h.connectHostFunc == nil {
  782. return
  783. }
  784. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  785. connectTicker := time.NewTicker(time.Millisecond * time.Duration(h.cf.ConnectCheck))
  786. for {
  787. select {
  788. case <-proxyTicker.C:
  789. now := time.Now().UnixMilli()
  790. h.lines.Range(func(id int, line *Line) bool {
  791. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  792. host, err := h.connectHostFunc(line.channel, Direct)
  793. if err != nil {
  794. log.Println("[proxyTicker connectHostFunc ERROR]", err)
  795. return false
  796. }
  797. err = h.ConnectToServer(line.channel, true, host)
  798. if err != nil {
  799. log.Println("[checkProxyConnect ConnectToServer WARNING]", err)
  800. }
  801. }
  802. return true
  803. })
  804. case <-connectTicker.C:
  805. h.lines.Range(func(id int, line *Line) bool {
  806. if line.host != nil && line.state == Disconnected {
  807. err := h.ConnectToServer(line.channel, true, nil)
  808. if err != nil {
  809. log.Println("[connectTicker ConnectToServer WARNING]", err)
  810. }
  811. }
  812. return true
  813. })
  814. }
  815. }
  816. }
  817. // 建立一个集线器
  818. // connectFunc 用于监听连接状态的函数,可以为nil
  819. func NewHub(
  820. cf *config.Config,
  821. channel string,
  822. // 客户端需要用的函数,提供连接的主机信息 (服务端可空)
  823. connectHostFunc ConnectHostFunc,
  824. // 验证函数,获取认证信息,用于发送给对方
  825. authFunc AuthFunc,
  826. // 核对发送过来的认证信息
  827. checkAuthFunc CheckAuthFunc,
  828. // 连接状态变化时调用的函数
  829. connectStatusFunc ConnectStatusFunc,
  830. // 验证发送数据的条件是否满足 (可为空)
  831. // checkConnectOkFunc CheckConnectOkFunc,
  832. ) (h *Hub) {
  833. if cf == nil {
  834. cf = config.NewConfig()
  835. }
  836. h = &Hub{
  837. cf: cf,
  838. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  839. channel: channel,
  840. middle: make([]MiddleFunc, 0),
  841. lines: NewMapx(),
  842. connectHostFunc: connectHostFunc,
  843. authFunc: authFunc,
  844. checkAuthFunc: checkAuthFunc,
  845. connectStatusFunc: connectStatusFunc,
  846. // checkConnectOkFunc: checkConnectOkFunc,
  847. lastCleanDeadConnect: time.Now().UnixMilli(),
  848. }
  849. go h.checkConnect()
  850. return h
  851. }