hub.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950
  1. package tinymq
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "math/rand"
  9. "net"
  10. // "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. "git.me9.top/git/tinymq/config"
  16. "git.me9.top/git/tinymq/conn"
  17. "git.me9.top/git/tinymq/conn/tcp2"
  18. "git.me9.top/git/tinymq/conn/ws2"
  19. )
  20. // 类似一个插座的功能,管理多个连接
  21. // 一个hub即可以是客户端,同时也可以是服务端
  22. // 为了简化流程和让通讯更加迅速,不再重发和缓存结果,采用超时的方式告诉应用层。
  23. // 截取部分字符串
  24. func subStr(str string, length int) string {
  25. if len(str) <= length {
  26. return str
  27. }
  28. return str[0:length] + "..."
  29. }
  30. type Hub struct {
  31. sync.Mutex
  32. ctx context.Context // 为了方便退出而建立
  33. cancel context.CancelFunc
  34. cf *config.Config
  35. globalID uint16
  36. channel string // 本地频道信息
  37. middle []MiddleFunc // 中间件
  38. lines *Mapx // 记录当前的连接,统一管理
  39. subscribes sync.Map // [cmd]->[]*SubscribeData //注册绑定频道的函数,用于响应请求
  40. msgCache sync.Map // map[uint16]*GetMsg //请求的回应记录,key为id
  41. // 客户端需要用的函数(服务端可为空)
  42. connectHostFunc ConnectHostFunc // 获取对应频道的一个连接地址
  43. // 返回认证信息,发送到对方
  44. authFunc AuthFunc // 获取认证信息,用于发送给对方
  45. // 核对发送过来的认证信息
  46. checkAuthFunc CheckAuthFunc // 核对认证是否合法
  47. // 连接状态变化时调用的函数
  48. connectStatusFunc ConnectStatusFunc
  49. // 验证发送数据的条件是否满足 (可空)
  50. checkConnectOkFunc CheckConnectOkFunc
  51. // 通过过滤函数获取一个频道信息 (可空)
  52. filterToChannelFunc FilterToChannelFunc
  53. // 上次清理异常连接时间戳
  54. lastCleanDeadConnect int64
  55. }
  56. func (h *Hub) SetCheckConnectOkFunc(fn CheckConnectOkFunc) {
  57. h.checkConnectOkFunc = fn
  58. }
  59. func (h *Hub) SetFilterToChannelFunc(fn FilterToChannelFunc) {
  60. h.filterToChannelFunc = fn
  61. }
  62. // 转换数据
  63. func (h *Hub) convertData(data any) (reqData []byte, err error) {
  64. switch data := data.(type) {
  65. case []byte:
  66. reqData = data
  67. case string:
  68. reqData = []byte(data)
  69. case func() ([]byte, error):
  70. reqData, err = data()
  71. if err != nil {
  72. log.Println(err.Error())
  73. return nil, err
  74. }
  75. default:
  76. if data != nil {
  77. // 自动转换数据为json格式
  78. reqData, err = json.Marshal(data)
  79. if err != nil {
  80. log.Println(err.Error())
  81. return nil, err
  82. }
  83. }
  84. }
  85. return
  86. }
  87. // 清理异常连接
  88. func (h *Hub) cleanDeadConnect() {
  89. now := time.Now().UnixMilli()
  90. expired := now - int64(h.cf.CleanDeadConnectWait)
  91. if h.lastCleanDeadConnect < expired {
  92. h.lastCleanDeadConnect = now
  93. h.lines.DeleteInvalidLines(expired)
  94. }
  95. }
  96. // 获取通讯消息ID号
  97. func (h *Hub) GetID() uint16 {
  98. h.Lock()
  99. defer h.Unlock()
  100. h.globalID++
  101. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  102. h.globalID = 1
  103. }
  104. for {
  105. // 检查是否在请求队列中存在对应的id
  106. if _, ok := h.msgCache.Load(h.globalID); ok {
  107. h.globalID++
  108. if h.globalID <= 0 || h.globalID >= config.ID_MAX {
  109. h.globalID = 1
  110. }
  111. } else {
  112. break
  113. }
  114. }
  115. return h.globalID
  116. }
  117. // 添加中间件
  118. // 如果中间件函数返回为空,表示处理完成,通过
  119. // 如果中间件函数返回 NEXT_MIDDLE,表示需要下一个中间件函数处理;如果没有下一函数则默认通过
  120. func (h *Hub) UseMiddle(middleFunc MiddleFunc) {
  121. h.middle = append(h.middle, middleFunc)
  122. }
  123. // 注册频道,其中频道为正则表达式字符串
  124. func (h *Hub) Subscribe(filter FilterFunc, cmd string, backFunc SubscribeBackFunc) (err error) {
  125. if filter == nil {
  126. return errors.New("filter function can not be nil")
  127. }
  128. reg := &SubscribeData{
  129. Filter: filter,
  130. Cmd: cmd,
  131. BackFunc: backFunc,
  132. }
  133. sub, ok := h.subscribes.Load(cmd)
  134. if ok {
  135. h.subscribes.Store(cmd, append(sub.([]*SubscribeData), reg))
  136. return
  137. }
  138. regs := make([]*SubscribeData, 1)
  139. regs[0] = reg
  140. h.subscribes.Store(cmd, regs)
  141. return
  142. }
  143. // 遍历频道列表
  144. // 如果 fn 返回 false,则 range 停止迭代
  145. func (h *Hub) ConnectRange(fn func(id int, line *Line) bool) {
  146. h.lines.Range(fn)
  147. }
  148. // 获取当前在线的数量
  149. func (h *Hub) ConnectNum() int {
  150. var count int
  151. h.lines.Range(func(id int, line *Line) bool {
  152. if line.state == Connected {
  153. count++
  154. }
  155. return true
  156. })
  157. return count
  158. }
  159. // 获取所有的在线连接频道
  160. func (h *Hub) AllChannel() []string {
  161. cs := make([]string, 0)
  162. h.lines.Range(func(id int, line *Line) bool {
  163. if line.state == Connected {
  164. cs = append(cs, line.channel)
  165. }
  166. return true
  167. })
  168. return cs
  169. }
  170. // 获取所有连接频道和连接时长
  171. // 为了避免定义数据结构麻烦,采用|隔开, 频道名|连接开始时间
  172. func (h *Hub) AllChannelWithStarted() []string {
  173. cs := make([]string, 0)
  174. h.lines.Range(func(id int, line *Line) bool {
  175. if line.state == Connected {
  176. cs = append(cs, fmt.Sprintf("%s|%d", line.channel, line.updated.UnixMilli()))
  177. }
  178. return true
  179. })
  180. return cs
  181. }
  182. // 获取频道并通过函数过滤,如果返回 false 将终止
  183. func (h *Hub) ChannelToFunc(fn func(string) bool) {
  184. h.lines.Range(func(id int, line *Line) bool {
  185. if line.state == Connected {
  186. return fn(line.channel)
  187. }
  188. return true
  189. })
  190. }
  191. // 从 channel 获取连接
  192. func (h *Hub) ChannelToLine(channel string) (line *Line) {
  193. h.lines.Range(func(id int, l *Line) bool {
  194. if l.IsChannelEqual(channel) {
  195. line = l
  196. return false
  197. }
  198. return true
  199. })
  200. return
  201. }
  202. // 返回请求结果
  203. func (h *Hub) outResponse(response *ResponseData) {
  204. defer recover() //避免管道已经关闭而引起panic
  205. id := response.Id
  206. t, ok := h.msgCache.Load(id)
  207. if ok {
  208. // 删除数据缓存
  209. h.msgCache.Delete(id)
  210. gm := t.(*GetMsg)
  211. // 停止定时器
  212. if !gm.timer.Stop() {
  213. select {
  214. case <-gm.timer.C:
  215. default:
  216. }
  217. }
  218. // 回应数据到上层
  219. gm.out <- response
  220. }
  221. }
  222. // 发送数据到网络接口
  223. // 返回发送的数量
  224. func (h *Hub) sendRequest(gd *GetData) (count int, err error) {
  225. outData, err := h.convertData(gd.Data)
  226. if err != nil {
  227. log.Println(err)
  228. return 0, err
  229. }
  230. doit := func(_ int, line *Line) bool {
  231. // 检查连接是否OK
  232. if line.state != Connected {
  233. return true
  234. }
  235. // 验证连接是否达到发送数据的要求
  236. if h.checkConnectOkFunc != nil && !h.checkConnectOkFunc(line, gd) {
  237. return true
  238. }
  239. if gd.Filter(line) {
  240. var id uint16
  241. if gd.backchan != nil {
  242. id = h.GetID()
  243. timeout := gd.Timeout
  244. if timeout <= 0 {
  245. timeout = h.cf.WriteWait
  246. }
  247. fn := func(id uint16, conn *Line) func() {
  248. return func() {
  249. go h.outResponse(&ResponseData{
  250. Id: id,
  251. State: config.GET_TIMEOUT,
  252. Data: fmt.Appendf(nil, "[%s] %s %s", IdMsg(config.GET_TIMEOUT), conn.channel, gd.Cmd),
  253. conn: conn,
  254. })
  255. // 检查是否已经很久时间没有使用连接了
  256. if time.Since(conn.lastRead) > time.Duration(h.cf.PingInterval*3*int(time.Millisecond)) {
  257. // 超时关闭当前的连接
  258. log.Println("get message timeout", conn.channel)
  259. // 有可能连接出现问题,断开并重新连接
  260. conn.Close(false)
  261. return
  262. }
  263. }
  264. }(id, line)
  265. // 将要发送的请求缓存
  266. gm := &GetMsg{
  267. out: gd.backchan,
  268. timer: time.AfterFunc(time.Millisecond*time.Duration(timeout), fn),
  269. }
  270. h.msgCache.Store(id, gm)
  271. }
  272. // 组织数据并发送到Connect
  273. line.sendRequest <- &RequestData{
  274. Id: id,
  275. Cmd: gd.Cmd,
  276. Data: outData,
  277. timeout: gd.Timeout,
  278. backchan: gd.backchan,
  279. conn: line,
  280. }
  281. if h.cf.PrintMsg {
  282. log.Println("[SEND]->", id, line.channel, "["+gd.Cmd+"]", subStr(string(outData), 200))
  283. }
  284. count++
  285. if gd.Max > 0 && count >= gd.Max {
  286. return false
  287. }
  288. }
  289. return true
  290. }
  291. // 如果没有发送到消息,延时重连直到超时
  292. for i := 0; i <= gd.Timeout; i += 500 {
  293. if gd.Rand {
  294. h.lines.RandRange(doit, i == 0)
  295. } else {
  296. h.lines.Range(doit)
  297. }
  298. if count > 0 {
  299. break
  300. }
  301. // 如果是客户端,并且有机会自动连接,则尝试自动连接
  302. if i == 0 && h.connectHostFunc != nil && h.filterToChannelFunc != nil {
  303. channel := h.filterToChannelFunc(gd.Filter)
  304. if channel == "" {
  305. err = errors.New(IdMsg(NO_MATCH_CONNECT))
  306. log.Println("not channel found")
  307. return 0, err
  308. }
  309. err := h.ConnectToServer(channel, false, nil, false)
  310. if err != nil {
  311. log.Println(err)
  312. return 0, err
  313. }
  314. } else {
  315. time.Sleep(time.Millisecond * 400) // 故意将时间缩小一点
  316. }
  317. }
  318. return
  319. }
  320. // 执行网络发送过来的命令
  321. func (h *Hub) requestFromNet(request *RequestData) {
  322. cmd := request.Cmd
  323. channel := request.conn.channel
  324. if h.cf.PrintMsg {
  325. log.Println("[REQU]<-", request.Id, channel, "["+cmd+"]", subStr(string(request.Data), 200))
  326. }
  327. // 执行中间件
  328. for _, mdFunc := range h.middle {
  329. rsp := mdFunc(request)
  330. if rsp != nil {
  331. // NEXT_MIDDLE 表示当前的函数没有处理完成,还需要下个中间件处理
  332. if rsp.State == config.NEXT_MIDDLE {
  333. continue
  334. }
  335. // 返回消息
  336. if request.Id != 0 {
  337. rsp.Id = request.Id
  338. request.conn.sendResponse <- rsp
  339. }
  340. return
  341. } else {
  342. break
  343. }
  344. }
  345. sub, ok := h.subscribes.Load(cmd)
  346. if ok {
  347. subs := sub.([]*SubscribeData)
  348. // 倒序查找是为了新增的频道响应函数优先执行
  349. for i := len(subs) - 1; i >= 0; i-- {
  350. rg := subs[i]
  351. if rg.Filter(request.conn) {
  352. state, data := rg.BackFunc(request)
  353. // NEXT_SUBSCRIBE 表示当前的函数没有处理完成,还需要下个注册函数处理
  354. if state == config.NEXT_SUBSCRIBE {
  355. continue
  356. }
  357. var byteData []byte
  358. switch data := data.(type) {
  359. case []byte:
  360. byteData = data
  361. case string:
  362. byteData = []byte(data)
  363. default:
  364. if data != nil {
  365. // 自动转换数据为json格式
  366. var err error
  367. byteData, err = json.Marshal(data)
  368. if err != nil {
  369. log.Println(err.Error())
  370. state = config.CONVERT_FAILED
  371. byteData = fmt.Appendf(nil, "[%s] %s %s", IdMsg(config.CONVERT_FAILED), request.conn.channel, request.Cmd)
  372. }
  373. }
  374. }
  375. // 如果id为0表示不需要回应
  376. if request.Id != 0 {
  377. request.conn.sendResponse <- &ResponseData{
  378. Id: request.Id,
  379. State: state,
  380. Data: byteData,
  381. }
  382. if h.cf.PrintMsg {
  383. log.Println("[RESP]->", request.Id, channel, "["+cmd+"]", state, subStr(string(byteData), 200))
  384. }
  385. }
  386. return
  387. }
  388. }
  389. }
  390. log.Println("[not match command]", channel, cmd)
  391. // 返回没有匹配的消息
  392. request.conn.sendResponse <- &ResponseData{
  393. Id: request.Id,
  394. State: config.NO_MATCH_CMD,
  395. Data: fmt.Appendf(nil, "[%s] Channel: %s, Cmd: %s", IdMsg(config.NO_MATCH_CMD), channel, cmd),
  396. }
  397. }
  398. // 请求频道并获取数据,采用回调的方式返回结果
  399. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  400. // 如果 backFunc 返回为 false 则提前结束
  401. // 最大数量和超时时间如果为0的话表示使用默认值
  402. func (h *Hub) GetWithStruct(gd *GetData, backFunc GetBackFunc) (count int, err error) {
  403. if gd.Filter == nil {
  404. return 0, errors.New("not filter found")
  405. }
  406. if gd.Timeout <= 0 {
  407. gd.Timeout = h.cf.WriteWait
  408. }
  409. if gd.backchan == nil {
  410. gd.backchan = make(chan *ResponseData, 32)
  411. }
  412. sendMax, err := h.sendRequest(gd)
  413. if err != nil {
  414. return 0, err
  415. }
  416. if sendMax <= 0 {
  417. return 0, nil
  418. }
  419. // 避免出现异常时线程无法退出
  420. timer := time.NewTimer(time.Millisecond * time.Duration(gd.Timeout+h.cf.WriteWait*2))
  421. defer func() {
  422. if !timer.Stop() {
  423. select {
  424. case <-timer.C:
  425. default:
  426. }
  427. }
  428. close(gd.backchan)
  429. }()
  430. for {
  431. select {
  432. case rp := <-gd.backchan:
  433. if rp == nil || rp.conn == nil {
  434. // 可能是已经退出了
  435. return
  436. }
  437. ch := rp.conn.channel
  438. if h.cf.PrintMsg {
  439. log.Println("[RECV]<-", rp.Id, ch, "["+gd.Cmd+"]", rp.State, subStr(string(rp.Data), 200))
  440. }
  441. count++
  442. // 如果这里返回为false这跳出循环
  443. if backFunc != nil && !backFunc(rp) {
  444. return
  445. }
  446. if count >= sendMax {
  447. return
  448. }
  449. case <-timer.C:
  450. return
  451. case <-h.ctx.Done():
  452. return
  453. }
  454. }
  455. }
  456. // 请求频道并获取数据,采用回调的方式返回结果
  457. // 当前调用将会阻塞,直到命令都执行结束,最后返回执行的数量
  458. // 如果 backFunc 返回为 false 则提前结束
  459. func (h *Hub) Get(filter FilterFunc, cmd string, data any, backFunc GetBackFunc) (count int, err error) {
  460. return h.GetWithStruct(&GetData{
  461. Filter: filter,
  462. Cmd: cmd,
  463. Data: data,
  464. }, backFunc)
  465. }
  466. // 获取一个数据,阻塞等待到超时间隔
  467. func (h *Hub) GetOneWithStruct(gd *GetData) (response *ResponseData) {
  468. if gd.Filter == nil {
  469. return &ResponseData{
  470. State: config.NO_MATCH_FILTER,
  471. Data: []byte(IdMsg(config.NO_MATCH_FILTER)),
  472. }
  473. }
  474. gd.Max = 1
  475. h.GetWithStruct(gd, func(rp *ResponseData) (ok bool) {
  476. response = rp
  477. return false
  478. })
  479. if response == nil {
  480. return &ResponseData{
  481. State: config.NO_MATCH_CONNECT,
  482. Data: []byte(IdMsg(config.NO_MATCH_CONNECT)),
  483. }
  484. }
  485. return
  486. }
  487. // 只获取一个频道的数据,阻塞等待到默认超时间隔
  488. // 如果没有结果将返回 NO_MATCH
  489. func (h *Hub) GetOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  490. return h.GetOneWithStruct(&GetData{
  491. Filter: filter,
  492. Cmd: cmd,
  493. Data: data,
  494. Max: 1,
  495. })
  496. }
  497. func (h *Hub) GetRandOne(filter FilterFunc, cmd string, data any) (response *ResponseData) {
  498. return h.GetOneWithStruct(&GetData{
  499. Filter: filter,
  500. Cmd: cmd,
  501. Data: data,
  502. Max: 1,
  503. Rand: true,
  504. })
  505. }
  506. // 只获取一个频道的数据,阻塞等待到指定超时间隔
  507. // 如果没有结果将返回 NO_MATCH
  508. func (h *Hub) GetOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  509. return h.GetOneWithStruct(&GetData{
  510. Filter: filter,
  511. Cmd: cmd,
  512. Data: data,
  513. Max: 1,
  514. Timeout: timeout,
  515. })
  516. }
  517. func (h *Hub) GetRandOneWithTimeout(filter FilterFunc, cmd string, data any, timeout int) (response *ResponseData) {
  518. return h.GetOneWithStruct(&GetData{
  519. Filter: filter,
  520. Cmd: cmd,
  521. Data: data,
  522. Max: 1,
  523. Timeout: timeout,
  524. Rand: true,
  525. })
  526. }
  527. // 推送消息出去,不需要返回数据
  528. func (h *Hub) Push(filter FilterFunc, cmd string, data any) {
  529. h.PushWithMax(filter, cmd, data, 0)
  530. }
  531. // 推送最大对应数量的消息出去,不需要返回数据
  532. func (h *Hub) PushWithMax(filter FilterFunc, cmd string, data any, max int) {
  533. // 排除空频道
  534. if filter == nil {
  535. return
  536. }
  537. gd := &GetData{
  538. Filter: filter,
  539. Cmd: cmd,
  540. Data: data,
  541. Max: max,
  542. Timeout: h.cf.WriteWait,
  543. backchan: nil,
  544. }
  545. h.sendRequest(gd)
  546. }
  547. // 增加连接
  548. func (h *Hub) addLine(line *Line) {
  549. if h.lines.Exist(line) {
  550. log.Println("connect have exist")
  551. // 连接已经存在,直接返回
  552. return
  553. }
  554. h.lines.Store(line)
  555. }
  556. // 删除连接
  557. func (h *Hub) removeLine(line *Line) {
  558. line.Close(true)
  559. h.lines.Delete(line)
  560. }
  561. // 获取指定连接的连接持续时间
  562. func (h *Hub) ConnectDuration(line *Line) time.Duration {
  563. return time.Since(line.started)
  564. }
  565. // 绑定端口,建立服务
  566. // 需要程序运行时调用
  567. func (h *Hub) BindForServer(info *HostInfo) (err error) {
  568. doConnectFunc := func(conn conn.Connect) {
  569. proto, version, channel, auth, err := conn.ReadAuthInfo()
  570. if err != nil {
  571. log.Println("[BindForServer ReadAuthInfo ERROR]", err)
  572. conn.Close()
  573. return
  574. }
  575. if version != info.Version || proto != info.Proto {
  576. log.Println("wrong version or protocol: ", version, proto)
  577. conn.Close()
  578. return
  579. }
  580. if !h.checkAuthFunc(false, proto, version, channel, auth) {
  581. err = fmt.Errorf("[server checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  582. log.Println(err)
  583. conn.Close()
  584. return
  585. }
  586. // 发送频道信息
  587. if err := conn.WriteAuthInfo(h.channel, h.authFunc(false, proto, version, channel, auth)); err != nil {
  588. log.Println("[WriteAuthInfo ERROR]", err)
  589. conn.Close()
  590. return
  591. }
  592. // 将连接加入现有连接中
  593. done := false
  594. h.lines.Range(func(id int, line *Line) bool {
  595. if line.state == Disconnected && line.host == nil && line.IsChannelEqual(channel) {
  596. line.Start(channel, conn, nil)
  597. done = true
  598. return false
  599. }
  600. return true
  601. })
  602. // 新建一个连接
  603. if !done {
  604. line := NewConnect(h.cf, h, channel, conn, nil, false)
  605. h.addLine(line)
  606. }
  607. }
  608. if info.Version == ws2.VERSION && info.Proto == ws2.PROTO {
  609. bind := ""
  610. if info.Bind != "" {
  611. bind = net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port)))
  612. }
  613. return ws2.Server(h.cf, bind, info.Path, info.Hash, doConnectFunc)
  614. } else if info.Version == tcp2.VERSION && info.Proto == tcp2.PROTO {
  615. return tcp2.Server(h.cf, net.JoinHostPort(info.Bind, strconv.Itoa(int(info.Port))), info.Hash, doConnectFunc)
  616. }
  617. return errors.New("not connect protocol and version found")
  618. }
  619. // 新建一个连接,不同的连接协议由底层自己选择
  620. // channel: 要连接的频道信息,需要能表达频道关键信息的部分
  621. func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo, autoReconnect bool) (err error) {
  622. // 检查当前channel是否已经存在
  623. if !force {
  624. line := h.ChannelToLine(channel)
  625. if line != nil && line.state == Connected {
  626. // err = fmt.Errorf("[ConnectToServer ERROR] existed channel: %s", channel)
  627. log.Println("[ConnectToServer] channel existed:", channel)
  628. return
  629. }
  630. }
  631. if host == nil {
  632. if h.connectHostFunc == nil {
  633. return errors.New("not connect host func found")
  634. }
  635. // 获取服务地址等信息
  636. host, err = h.connectHostFunc(channel, Both)
  637. if err != nil {
  638. return err
  639. }
  640. if host == nil {
  641. // 如果地址为空表示不需要连接,自己返回
  642. return
  643. }
  644. }
  645. var conn conn.Connect
  646. var runProto string
  647. addr := net.JoinHostPort(host.Host, strconv.Itoa(int(host.Port)))
  648. // 添加定时器
  649. ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(h.cf.ConnectTimeout))
  650. defer cancel()
  651. taskCh := make(chan bool)
  652. done := false
  653. go func() {
  654. if host.Version == ws2.VERSION && (host.Proto == ws2.PROTO || host.Proto == ws2.PROTO_STL) {
  655. runProto = ws2.PROTO
  656. if h.cf.PrintMsg {
  657. log.Println("[ConnectToServer] connecting:", host.Proto, addr, host.Path, host.Hash)
  658. }
  659. conn, err = ws2.Dial(h.cf, host.Proto, addr, host.Path, host.Hash)
  660. } else if host.Version == tcp2.VERSION && host.Proto == tcp2.PROTO {
  661. runProto = tcp2.PROTO
  662. if h.cf.PrintMsg {
  663. log.Println("[ConnectToServer] connecting:", host.Proto, addr, host.Hash)
  664. }
  665. conn, err = tcp2.Dial(h.cf, addr, host.Hash)
  666. } else {
  667. err = fmt.Errorf("not correct protocol and version found in: %+v", host)
  668. }
  669. if done {
  670. if err != nil {
  671. log.Println("[ConnectToServer] dial failed:", err)
  672. }
  673. if conn != nil {
  674. conn.Close()
  675. }
  676. } else {
  677. taskCh <- err == nil
  678. }
  679. }()
  680. select {
  681. case ok := <-taskCh:
  682. cancel()
  683. if !ok || err != nil || conn == nil {
  684. log.Println("[ConnectToServer] connect failed:", err)
  685. host.Errors++
  686. host.Updated = time.Now()
  687. if err == nil {
  688. err = errors.New("unknown error")
  689. }
  690. return err
  691. }
  692. case <-ctx.Done():
  693. done = true
  694. return errors.New("timeout")
  695. case <-h.ctx.Done():
  696. return errors.New("quit")
  697. }
  698. // 如果 host 是代理,将代理信息添加到channel中
  699. localChannel := h.channel
  700. if host.Proxy {
  701. localChannel = localChannel + "?proxy=" + host.Host
  702. }
  703. // 发送验证信息
  704. if err := conn.WriteAuthInfo(localChannel, h.authFunc(true, runProto, host.Version, channel, nil)); err != nil {
  705. log.Println("[ConnectToServer] WriteAuthInfo failed:", err)
  706. conn.Close()
  707. host.Errors++
  708. host.Updated = time.Now()
  709. return err
  710. }
  711. // 接收频道信息
  712. proto, version, channel2, auth, err := conn.ReadAuthInfo()
  713. if err != nil {
  714. log.Println("[ConnectToServer] ReadAuthInfo failed:", err)
  715. conn.Close()
  716. host.Errors++
  717. host.Updated = time.Now()
  718. return err
  719. }
  720. // 检查版本和协议是否一致
  721. if version != host.Version || proto != runProto {
  722. err = fmt.Errorf("[ConnectToServer] version or protocol wrong: %d, %s", version, proto)
  723. log.Println(err)
  724. conn.Close()
  725. host.Errors++
  726. host.Updated = time.Now()
  727. return err
  728. }
  729. // 检查频道名称是否匹配
  730. if !strings.Contains(channel2, channel) {
  731. err = fmt.Errorf("[ConnectToServer] channel want %s, get %s", channel, channel2)
  732. log.Println(err)
  733. conn.Close()
  734. host.Errors++
  735. host.Updated = time.Now()
  736. return err
  737. }
  738. // 检查验证是否合法
  739. if !h.checkAuthFunc(true, proto, version, channel, auth) {
  740. err = fmt.Errorf("[ConnectToServer] client checkAuthFunc in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
  741. log.Println(err)
  742. conn.Close()
  743. host.Errors++
  744. host.Updated = time.Now()
  745. return err
  746. }
  747. // 更新服务主机信息
  748. host.Errors = 0
  749. host.Updated = time.Now()
  750. // 将连接加入现有连接中
  751. done = false
  752. h.lines.Range(func(id int, line *Line) bool {
  753. if line.channel == channel {
  754. if line.state == Connected {
  755. if force {
  756. line.Close(true)
  757. } else {
  758. err = fmt.Errorf("[connectToServer] channel already connected: %s", channel)
  759. log.Println(err)
  760. return false
  761. }
  762. }
  763. line.Start(channel, conn, host)
  764. done = true
  765. return false
  766. }
  767. return true
  768. })
  769. if err != nil {
  770. return err
  771. }
  772. // 新建一个连接
  773. if !done {
  774. line := NewConnect(h.cf, h, channel, conn, host, autoReconnect)
  775. h.addLine(line)
  776. }
  777. return nil
  778. }
  779. // 重试方式连接服务
  780. // 将会一直阻塞直到连接成功
  781. func (h *Hub) ConnectToServerX(channel string, force bool, host *HostInfo) {
  782. if host == nil {
  783. if h.connectHostFunc == nil {
  784. log.Println("[ConnectToServerX] not connect host func found")
  785. return
  786. }
  787. host, _ = h.connectHostFunc(channel, Direct)
  788. if host == nil {
  789. log.Println("[ConnectToServerX] not host found for channel:", channel)
  790. return
  791. }
  792. }
  793. for {
  794. err := h.ConnectToServer(channel, force, host, true)
  795. if err == nil {
  796. return
  797. }
  798. log.Println("[ConnectToServerX] connect failed:", channel, host.Url(), err)
  799. // 产生一个随机数避免刹间重连过载
  800. delay := time.Duration(rand.Intn(h.cf.ConnectTimeout)+(h.cf.ConnectTimeout/2)) * time.Millisecond
  801. log.Println("[ConnectToServerX] will reconnect with delay:", delay)
  802. host = nil
  803. time.Sleep(delay)
  804. }
  805. }
  806. // 检测处理连接状态,只在客户端有效
  807. func (h *Hub) checkConnect() {
  808. // 检查客户端获取主机地址函数
  809. if h.connectHostFunc == nil {
  810. return
  811. }
  812. proxyTicker := time.NewTicker(time.Duration(h.cf.ProxyTimeout * int(time.Millisecond)))
  813. connectTicker := time.NewTicker(time.Millisecond * time.Duration(h.cf.ConnectCheck))
  814. for {
  815. select {
  816. case <-proxyTicker.C:
  817. now := time.Now().UnixMilli()
  818. h.lines.Range(func(id int, line *Line) bool {
  819. if line.host != nil && line.host.Proxy && now-line.updated.UnixMilli() > int64(h.cf.ProxyTimeout) {
  820. host, err := h.connectHostFunc(line.channel, Direct)
  821. if err != nil {
  822. log.Println("[proxyTicker connectHostFunc]", err)
  823. return false
  824. }
  825. if host != nil {
  826. err = h.ConnectToServer(line.channel, true, host, line.autoReconnect)
  827. if err != nil {
  828. log.Println("[checkProxyConnect ConnectToServer]", err)
  829. }
  830. }
  831. }
  832. return true
  833. })
  834. case <-connectTicker.C:
  835. h.lines.Range(func(id int, line *Line) bool {
  836. if line.autoReconnect && line.host != nil && line.state == Disconnected {
  837. err := h.ConnectToServer(line.channel, true, nil, true)
  838. if err != nil {
  839. log.Println("[connectTicker ConnectToServer]", err)
  840. }
  841. }
  842. return true
  843. })
  844. case <-h.ctx.Done():
  845. return
  846. }
  847. }
  848. }
  849. // 退出所有的连接
  850. func (h *Hub) Quit() {
  851. h.cancel()
  852. h.lines.Range(func(id int, line *Line) bool {
  853. if line.state == Connected {
  854. line.Close(true)
  855. }
  856. return true
  857. })
  858. }
  859. // 建立一个集线器
  860. // connectFunc 用于监听连接状态的函数,可以为nil
  861. func NewHub(
  862. cf *config.Config,
  863. channel string,
  864. // 客户端需要用的函数,提供连接的主机信息 (服务端可空)
  865. connectHostFunc ConnectHostFunc,
  866. // 验证函数,获取认证信息,用于发送给对方
  867. authFunc AuthFunc,
  868. // 核对发送过来的认证信息
  869. checkAuthFunc CheckAuthFunc,
  870. // 连接状态变化时调用的函数
  871. connectStatusFunc ConnectStatusFunc,
  872. // 验证发送数据的条件是否满足 (可为空)
  873. // checkConnectOkFunc CheckConnectOkFunc,
  874. ) (h *Hub) {
  875. if cf == nil {
  876. cf = config.NewConfig()
  877. }
  878. ctx, cancel := context.WithCancel(context.Background())
  879. h = &Hub{
  880. ctx: ctx,
  881. cancel: cancel,
  882. cf: cf,
  883. globalID: uint16(time.Now().UnixNano()) % config.ID_MAX,
  884. channel: channel,
  885. middle: make([]MiddleFunc, 0),
  886. lines: NewMapx(),
  887. connectHostFunc: connectHostFunc,
  888. authFunc: authFunc,
  889. checkAuthFunc: checkAuthFunc,
  890. connectStatusFunc: connectStatusFunc,
  891. // checkConnectOkFunc: checkConnectOkFunc,
  892. lastCleanDeadConnect: time.Now().UnixMilli(),
  893. }
  894. go h.checkConnect()
  895. return h
  896. }