// File: tcp2.go (12 KB)

  1. package tcp2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net"
  10. "strings"
  11. "time"
  12. "git.me9.top/git/tinymq/config"
  13. "git.me9.top/git/tinymq/conn"
  14. "git.me9.top/git/tinymq/conn/util"
  15. )
  16. const PROTO string = "tcp"
  17. const VERSION uint8 = 2
  18. // 数据包的最大长度
  19. const MAX_LENGTH = 0xFFFF
  20. const MAX2_LENGTH = 0x1FFFFFFF // 500 M,避免申请过大内存
  21. type Tcp2 struct {
  22. cf *config.Config
  23. conn net.Conn
  24. cipher *util.Cipher // 记录当前的加解密类
  25. compress bool // 是否启用压缩
  26. }
  27. // 服务端
  28. // hash 格式 encryptMethod:encryptKey
  29. func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFunc) (err error) {
  30. var ci *util.CipherInfo
  31. var encryptKey string
  32. if hash != "" {
  33. i := strings.Index(hash, ":")
  34. if i <= 0 {
  35. err = errors.New("hash is invalid")
  36. return
  37. }
  38. encryptMethod := hash[0:i]
  39. encryptKey = hash[i+1:]
  40. if c, ok := util.CipherMethod[encryptMethod]; ok {
  41. ci = c
  42. } else {
  43. return errors.New("Unsupported encryption method: " + encryptMethod)
  44. }
  45. }
  46. log.Printf("Listening and serving tcp on %s\n", bind)
  47. l, err := net.Listen("tcp", bind)
  48. if err != nil {
  49. log.Println("[tcp2 Server ERROR]", err)
  50. return
  51. }
  52. go func(l net.Listener) {
  53. defer l.Close()
  54. for {
  55. conn, err := l.Accept()
  56. if err != nil {
  57. log.Println("[accept ERROR]", err)
  58. return
  59. }
  60. go func(conn net.Conn) {
  61. if ci == nil {
  62. c := &Tcp2{
  63. cf: cf,
  64. conn: conn,
  65. }
  66. fn(c)
  67. return
  68. }
  69. var eiv []byte
  70. var div []byte
  71. if ci.IvLen > 0 {
  72. // 服务端 IV
  73. eiv = make([]byte, ci.IvLen)
  74. _, err = rand.Read(eiv)
  75. if err != nil {
  76. log.Println("[tcp2 Server rand.Read ERROR]", err)
  77. return
  78. }
  79. // 发送 IV
  80. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  81. if _, err := conn.Write(eiv); err != nil {
  82. log.Println("[tcp2 Server conn.Write ERROR]", err)
  83. return
  84. }
  85. // 读取 IV
  86. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  87. if err != nil {
  88. log.Println("[tcp2 Server SetReadDeadline ERROR]", err)
  89. return
  90. }
  91. div = make([]byte, ci.IvLen)
  92. _, err := io.ReadFull(conn, div)
  93. if err != nil {
  94. log.Println("[tcp2 Server ReadFull ERROR]", err)
  95. return
  96. }
  97. }
  98. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  99. if err != nil {
  100. log.Println("[tcp2 NewCipher ERROR]", err)
  101. return
  102. }
  103. // 初始化
  104. c := &Tcp2{
  105. cf: cf,
  106. conn: conn,
  107. cipher: cipher,
  108. }
  109. fn(c)
  110. }(conn)
  111. }
  112. }(l)
  113. return
  114. }
  115. // 客户端,新建一个连接
  116. func Dial(cf *config.Config, addr string, hash string) (conn.Connect, error) {
  117. // 没有加密的情况
  118. if hash == "" {
  119. conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
  120. if err != nil {
  121. return nil, err
  122. }
  123. c := &Tcp2{
  124. cf: cf,
  125. conn: conn,
  126. }
  127. return c, nil
  128. }
  129. i := strings.Index(hash, ":")
  130. if i <= 0 {
  131. return nil, errors.New("hash is invalid")
  132. }
  133. encryptMethod := hash[0:i]
  134. encryptKey := hash[i+1:]
  135. ci, ok := util.CipherMethod[encryptMethod]
  136. if !ok {
  137. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  138. }
  139. conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
  140. if err != nil {
  141. return nil, err
  142. }
  143. var eiv []byte
  144. var div []byte
  145. if ci.IvLen > 0 {
  146. // 客户端 IV
  147. eiv = make([]byte, ci.IvLen)
  148. _, err = rand.Read(eiv)
  149. if err != nil {
  150. log.Println("[tcp2 Client rand.Read ERROR]", err)
  151. return nil, err
  152. }
  153. // 发送 IV
  154. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  155. if _, err := conn.Write(eiv); err != nil {
  156. log.Println("[tcp2 Client conn.Write ERROR]", err)
  157. return nil, err
  158. }
  159. // 读取 IV
  160. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  161. if err != nil {
  162. log.Println("[tcp2 Client SetReadDeadline ERROR]", err)
  163. return nil, err
  164. }
  165. div = make([]byte, ci.IvLen)
  166. _, err := io.ReadFull(conn, div)
  167. if err != nil {
  168. log.Println("[tcp2 Client ReadFull ERROR]", err)
  169. return nil, err
  170. }
  171. }
  172. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  173. if err != nil {
  174. log.Println("[tcp2 NewCipher ERROR]", err)
  175. return nil, err
  176. }
  177. // 初始化
  178. c := &Tcp2{
  179. cf: cf,
  180. conn: conn,
  181. cipher: cipher,
  182. compress: true,
  183. }
  184. return c, nil
  185. }
  186. // 发送数据到网络
  187. // 如果有加密函数的话会直接修改源数据
  188. func (c *Tcp2) writeMessage(buf []byte) (err error) {
  189. if len(buf) > MAX2_LENGTH {
  190. return fmt.Errorf("data length more than %d", MAX2_LENGTH)
  191. }
  192. if c.cipher != nil {
  193. c.cipher.Encrypt(buf, buf)
  194. }
  195. c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.cf.WriteWait) * time.Millisecond))
  196. for {
  197. n, err := c.conn.Write(buf)
  198. if err != nil {
  199. return err
  200. }
  201. if n < len(buf) {
  202. buf = buf[n:]
  203. } else {
  204. return nil
  205. }
  206. }
  207. }
  208. // 申请内存并写入数据长度信息
  209. // 还多申请一个字节用于保存crc
  210. func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) {
  211. if dlen >= MAX_LENGTH {
  212. buf = make([]byte, dlen+2+4+1)
  213. start = 2 + 4
  214. binary.BigEndian.PutUint16(buf[:2], MAX_LENGTH)
  215. binary.BigEndian.PutUint32(buf[2:6], uint32(dlen))
  216. } else {
  217. buf = make([]byte, dlen+2+1)
  218. start = 2
  219. binary.BigEndian.PutUint16(buf[:2], uint16(dlen))
  220. }
  221. return
  222. }
  223. // 发送Auth信息
  224. // 建立连接后第一个发送的消息
  225. func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
  226. protoLen := len(PROTO)
  227. if protoLen > 0xFF {
  228. return errors.New("length of protocol over")
  229. }
  230. channelLen := len(channel)
  231. if channelLen > 0xFFFF {
  232. return errors.New("length of channel over")
  233. }
  234. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  235. dlen := 2 + 1 + protoLen + 1 + 1 + 2 + channelLen + len(auth)
  236. buf, start := c.writeDataLen(dlen)
  237. index := start
  238. binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
  239. index += 2
  240. buf[index] = byte(protoLen)
  241. index++
  242. copy(buf[index:index+protoLen], []byte(PROTO))
  243. index += protoLen
  244. buf[index] = VERSION
  245. index++
  246. if c.compress {
  247. buf[index] = 0x01
  248. } else {
  249. buf[index] = 0
  250. }
  251. index++
  252. binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
  253. index += 2
  254. copy(buf[index:index+channelLen], []byte(channel))
  255. index += channelLen
  256. copy(buf[index:], auth)
  257. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  258. return c.writeMessage(buf)
  259. }
  260. // 从连接中读取信息
  261. func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
  262. buf := make([]byte, 2)
  263. err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  264. if err != nil {
  265. return nil, err
  266. }
  267. // 读取数据流长度
  268. _, err = io.ReadFull(c.conn, buf)
  269. if err != nil {
  270. return nil, err
  271. }
  272. // 将读出来的数据进行解密
  273. if c.cipher != nil {
  274. c.cipher.Decrypt(buf, buf)
  275. }
  276. dlen := uint32(binary.BigEndian.Uint16(buf))
  277. if dlen < 2 {
  278. return nil, errors.New("length is less to 2")
  279. }
  280. if dlen >= MAX_LENGTH {
  281. // 数据包比较大,通过后面的4位长度来表示实际长度
  282. buf = make([]byte, 4)
  283. _, err := io.ReadFull(c.conn, buf)
  284. if err != nil {
  285. return nil, err
  286. }
  287. if c.cipher != nil {
  288. c.cipher.Decrypt(buf, buf)
  289. }
  290. dlen = binary.BigEndian.Uint32(buf)
  291. if dlen < MAX_LENGTH || dlen > MAX2_LENGTH {
  292. return nil, errors.New("wrong length in read message")
  293. }
  294. }
  295. // 读取指定长度的数据
  296. buf = make([]byte, dlen+1) // 最后一个是crc的值
  297. _, err = io.ReadFull(c.conn, buf)
  298. if err != nil {
  299. return nil, err
  300. }
  301. if c.cipher != nil {
  302. c.cipher.Decrypt(buf, buf)
  303. }
  304. // 检查CRC8
  305. if util.CRC8(buf[:dlen]) != buf[dlen] {
  306. return nil, errors.New("CRC error")
  307. }
  308. return buf[:dlen], nil
  309. }
  310. // 获取Auth信息
  311. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  312. func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  313. defer func() {
  314. if r := recover(); r != nil {
  315. err = fmt.Errorf("recovered from panic: %v", r)
  316. return
  317. }
  318. }()
  319. msg, err := c.readMessage(c.cf.ReadWait)
  320. if err != nil {
  321. return
  322. }
  323. msgLen := len(msg)
  324. if msgLen < 9 {
  325. err = errors.New("wrong message length")
  326. return
  327. }
  328. index := 0
  329. id := binary.BigEndian.Uint16(msg[index : index+2])
  330. if id != config.ID_AUTH {
  331. err = fmt.Errorf("wrong message id: %d", id)
  332. return
  333. }
  334. index += 2
  335. protoLen := int(msg[index])
  336. if protoLen < 2 {
  337. err = errors.New("wrong proto length")
  338. return
  339. }
  340. index++
  341. proto = string(msg[index : index+protoLen])
  342. if proto != PROTO {
  343. err = fmt.Errorf("wrong proto: %s", proto)
  344. return
  345. }
  346. index += protoLen
  347. version = msg[index]
  348. if version != VERSION {
  349. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  350. return
  351. }
  352. index++
  353. c.compress = (msg[index] & 0x01) != 0
  354. index++
  355. channelLen := int(binary.BigEndian.Uint16(msg[index : index+2]))
  356. if channelLen < 2 {
  357. err = errors.New("wrong channel length")
  358. return
  359. }
  360. index += 2
  361. channel = string(msg[index : index+channelLen])
  362. index += channelLen
  363. auth = msg[index:]
  364. return
  365. }
  366. // 发送请求数据包到网络
  367. func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error {
  368. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则截断
  369. cmdLen := len(cmd)
  370. if cmdLen > 0x7F {
  371. return errors.New("command length is more than 0x7F")
  372. }
  373. dlen := len(data)
  374. if c.compress && dlen > 0 {
  375. compressedData, ok := util.CompressData(data)
  376. ddlen := 2 + 1 + cmdLen + 1 + len(compressedData)
  377. buf, start := c.writeDataLen(ddlen)
  378. index := start
  379. binary.BigEndian.PutUint16(buf[index:index+2], id)
  380. index += 2
  381. buf[index] = byte(cmdLen)
  382. index++
  383. copy(buf[index:index+cmdLen], cmd)
  384. index += cmdLen
  385. if ok {
  386. buf[index] = 0x01
  387. } else {
  388. buf[index] = 0
  389. }
  390. index++
  391. copy(buf[index:], compressedData)
  392. buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
  393. return c.writeMessage(buf)
  394. } else {
  395. ddlen := 2 + 1 + cmdLen + dlen
  396. buf, start := c.writeDataLen(ddlen)
  397. index := start
  398. binary.BigEndian.PutUint16(buf[index:index+2], id)
  399. index += 2
  400. buf[index] = byte(cmdLen)
  401. index++
  402. copy(buf[index:index+cmdLen], cmd)
  403. index += cmdLen
  404. copy(buf[index:], data)
  405. buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
  406. return c.writeMessage(buf)
  407. }
  408. }
  409. // 发送响应数据包到网络
  410. // 网络格式:[id, stateCode, data]
  411. func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error {
  412. dlen := len(data)
  413. if c.compress && dlen > 0 {
  414. compressedData, ok := util.CompressData(data)
  415. ddlen := 2 + 1 + 1 + len(compressedData)
  416. buf, start := c.writeDataLen(ddlen)
  417. index := start
  418. binary.BigEndian.PutUint16(buf[index:index+2], id)
  419. index += 2
  420. buf[index] = state | 0x80
  421. index++
  422. if ok {
  423. buf[index] = 0x001
  424. } else {
  425. buf[index] = 0
  426. }
  427. index++
  428. copy(buf[index:], compressedData)
  429. buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
  430. return c.writeMessage(buf)
  431. } else {
  432. ddlen := 2 + 1 + len(data)
  433. buf, start := c.writeDataLen(ddlen)
  434. index := start
  435. binary.BigEndian.PutUint16(buf[index:index+2], id)
  436. index += 2
  437. buf[index] = state | 0x80
  438. index++
  439. copy(buf[index:], data)
  440. buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
  441. return c.writeMessage(buf)
  442. }
  443. }
  444. // 发送ping包
  445. func (c *Tcp2) WritePing(id uint16) error {
  446. dlen := 2
  447. buf, start := c.writeDataLen(dlen)
  448. index := start
  449. binary.BigEndian.PutUint16(buf[index:index+2], id)
  450. // index += 2
  451. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  452. return c.writeMessage(buf)
  453. }
  454. // 获取信息
  455. func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  456. msg, err := c.readMessage(deadline)
  457. if err != nil {
  458. return
  459. }
  460. msgLen := len(msg)
  461. id = binary.BigEndian.Uint16(msg[0:2])
  462. // ping信息
  463. if msgLen == 2 {
  464. msgType = conn.PingMsg
  465. return
  466. }
  467. if id > config.ID_MAX {
  468. err = fmt.Errorf("wrong message id: %d", id)
  469. return
  470. }
  471. cmdx := msg[2]
  472. if (cmdx & 0x80) == 0 {
  473. // 请求包
  474. msgType = conn.RequestMsg
  475. cmdLen := int(cmdx)
  476. cmd = string(msg[3 : cmdLen+3])
  477. data = msg[cmdLen+3:]
  478. } else {
  479. // 响应数据包
  480. msgType = conn.ResponseMsg
  481. state = cmdx & 0x7F
  482. data = msg[3:]
  483. }
  484. if c.compress && len(data) > 1 {
  485. data, _ = util.UncompressData(data)
  486. }
  487. return
  488. }
  489. // 获取远程的地址
  490. func (c *Tcp2) RemoteAddr() net.Addr {
  491. return c.conn.RemoteAddr()
  492. }
  493. // 获取本地的地址
  494. func (c *Tcp2) LocalAddr() net.Addr {
  495. return c.conn.LocalAddr()
  496. }
  497. func (c *Tcp2) Close() error {
  498. return c.conn.Close()
  499. }