tcp2.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. package tcp2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net"
  10. "strings"
  11. "time"
  12. "git.me9.top/git/tinymq/config"
  13. "git.me9.top/git/tinymq/conn"
  14. "git.me9.top/git/tinymq/conn/util"
  15. )
  // Wire-protocol identification and framing limits.
  const (
  	// PROTO is the protocol name written into the auth packet.
  	PROTO string = "tcp"
  	// VERSION is the protocol version written into the auth packet.
  	VERSION uint8 = 2
  	// MAX_LENGTH is the largest value of the 2-byte length header. The value
  	// itself acts as a marker: a payload of this size or larger is framed with
  	// an additional 4-byte length field.
  	MAX_LENGTH = 0xFFFF
  	// MAX2_LENGTH caps the payload size (about 512 MiB) so that a peer cannot
  	// make us allocate an arbitrarily large buffer.
  	MAX2_LENGTH = 0x1FFFFFFF
  )
  21. type Tcp2 struct {
  22. cf *config.Config
  23. conn net.Conn
  24. cipher *util.Cipher // 记录当前的加解密类
  25. }
  26. // 服务端
  27. // hash 格式 encryptMethod:encryptKey
  28. func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFunc) (err error) {
  29. var ci *util.CipherInfo
  30. var encryptKey string
  31. if hash != "" {
  32. i := strings.Index(hash, ":")
  33. if i <= 0 {
  34. err = errors.New("hash is invalid")
  35. return
  36. }
  37. encryptMethod := hash[0:i]
  38. encryptKey = hash[i+1:]
  39. if c, ok := util.CipherMethod[encryptMethod]; ok {
  40. ci = c
  41. } else {
  42. return errors.New("Unsupported encryption method: " + encryptMethod)
  43. }
  44. }
  45. log.Printf("Listening and serving tcp on %s\n", bind)
  46. l, err := net.Listen("tcp", bind)
  47. if err != nil {
  48. log.Println("[tcp2 Server ERROR]", err)
  49. return
  50. }
  51. go func(l net.Listener) {
  52. defer l.Close()
  53. for {
  54. conn, err := l.Accept()
  55. if err != nil {
  56. log.Println("[accept ERROR]", err)
  57. return
  58. }
  59. go func(conn net.Conn) {
  60. if ci == nil {
  61. c := &Tcp2{
  62. cf: cf,
  63. conn: conn,
  64. }
  65. fn(c)
  66. return
  67. }
  68. var eiv []byte
  69. var div []byte
  70. if ci.IvLen > 0 {
  71. // 服务端 IV
  72. eiv = make([]byte, ci.IvLen)
  73. _, err = rand.Read(eiv)
  74. if err != nil {
  75. log.Println("[tcp2 Server rand.Read ERROR]", err)
  76. return
  77. }
  78. // 发送 IV
  79. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  80. if _, err := conn.Write(eiv); err != nil {
  81. log.Println("[tcp2 Server conn.Write ERROR]", err)
  82. return
  83. }
  84. // 读取 IV
  85. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  86. if err != nil {
  87. log.Println("[tcp2 Server SetReadDeadline ERROR]", err)
  88. return
  89. }
  90. div = make([]byte, ci.IvLen)
  91. _, err := io.ReadFull(conn, div)
  92. if err != nil {
  93. log.Println("[tcp2 Server ReadFull ERROR]", err)
  94. return
  95. }
  96. }
  97. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  98. if err != nil {
  99. log.Println("[tcp2 NewCipher ERROR]", err)
  100. return
  101. }
  102. // 初始化
  103. c := &Tcp2{
  104. cf: cf,
  105. conn: conn,
  106. cipher: cipher,
  107. }
  108. fn(c)
  109. }(conn)
  110. }
  111. }(l)
  112. return
  113. }
  114. // 客户端,新建一个连接
  115. func Dial(cf *config.Config, addr string, hash string) (conn.Connect, error) {
  116. // 没有加密的情况
  117. if hash == "" {
  118. conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
  119. if err != nil {
  120. return nil, err
  121. }
  122. c := &Tcp2{
  123. cf: cf,
  124. conn: conn,
  125. }
  126. return c, nil
  127. }
  128. i := strings.Index(hash, ":")
  129. if i <= 0 {
  130. return nil, errors.New("hash is invalid")
  131. }
  132. encryptMethod := hash[0:i]
  133. encryptKey := hash[i+1:]
  134. ci, ok := util.CipherMethod[encryptMethod]
  135. if !ok {
  136. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  137. }
  138. conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
  139. if err != nil {
  140. return nil, err
  141. }
  142. var eiv []byte
  143. var div []byte
  144. if ci.IvLen > 0 {
  145. // 客户端 IV
  146. eiv = make([]byte, ci.IvLen)
  147. _, err = rand.Read(eiv)
  148. if err != nil {
  149. log.Println("[tcp2 Client rand.Read ERROR]", err)
  150. return nil, err
  151. }
  152. // 发送 IV
  153. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  154. if _, err := conn.Write(eiv); err != nil {
  155. log.Println("[tcp2 Client conn.Write ERROR]", err)
  156. return nil, err
  157. }
  158. // 读取 IV
  159. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  160. if err != nil {
  161. log.Println("[tcp2 Client SetReadDeadline ERROR]", err)
  162. return nil, err
  163. }
  164. div = make([]byte, ci.IvLen)
  165. _, err := io.ReadFull(conn, div)
  166. if err != nil {
  167. log.Println("[tcp2 Client ReadFull ERROR]", err)
  168. return nil, err
  169. }
  170. }
  171. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  172. if err != nil {
  173. log.Println("[tcp2 NewCipher ERROR]", err)
  174. return nil, err
  175. }
  176. // 初始化
  177. c := &Tcp2{
  178. cf: cf,
  179. conn: conn,
  180. cipher: cipher,
  181. }
  182. return c, nil
  183. }
  184. // 发送数据到网络
  185. // 如果有加密函数的话会直接修改源数据
  186. func (c *Tcp2) writeMessage(buf []byte) (err error) {
  187. if len(buf) > MAX2_LENGTH {
  188. return fmt.Errorf("data length more than %d", MAX2_LENGTH)
  189. }
  190. if c.cipher != nil {
  191. c.cipher.Encrypt(buf, buf)
  192. }
  193. c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.cf.WriteWait) * time.Millisecond))
  194. for {
  195. n, err := c.conn.Write(buf)
  196. if err != nil {
  197. return err
  198. }
  199. if n < len(buf) {
  200. buf = buf[n:]
  201. } else {
  202. return nil
  203. }
  204. }
  205. }
  206. // 申请内存并写入数据长度信息
  207. // 还多申请一个字节用于保存crc
  208. func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) {
  209. if dlen >= MAX_LENGTH {
  210. buf = make([]byte, dlen+2+4+1)
  211. start = 2 + 4
  212. binary.BigEndian.PutUint16(buf[:2], MAX_LENGTH)
  213. binary.BigEndian.PutUint32(buf[2:6], uint32(dlen))
  214. } else {
  215. buf = make([]byte, dlen+2+1)
  216. start = 2
  217. binary.BigEndian.PutUint16(buf[:2], uint16(dlen))
  218. }
  219. return
  220. }
  221. // 发送Auth信息
  222. // 建立连接后第一个发送的消息
  223. func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
  224. protoLen := len(PROTO)
  225. if protoLen > 0xFF {
  226. return errors.New("length of protocol over")
  227. }
  228. channelLen := len(channel)
  229. if channelLen > 0xFFFF {
  230. return errors.New("length of channel over")
  231. }
  232. // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
  233. dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
  234. buf, start := c.writeDataLen(dlen)
  235. index := start
  236. binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
  237. index += 2
  238. buf[index] = byte(protoLen)
  239. index++
  240. copy(buf[index:index+protoLen], []byte(PROTO))
  241. index += protoLen
  242. buf[index] = VERSION
  243. index++
  244. binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
  245. index += 2
  246. copy(buf[index:index+channelLen], []byte(channel))
  247. index += channelLen
  248. copy(buf[index:], auth)
  249. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  250. return c.writeMessage(buf)
  251. }
  252. // 从连接中读取信息
  253. func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
  254. buf := make([]byte, 2)
  255. err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  256. if err != nil {
  257. return nil, err
  258. }
  259. // 读取数据流长度
  260. _, err = io.ReadFull(c.conn, buf)
  261. if err != nil {
  262. return nil, err
  263. }
  264. // 将读出来的数据进行解密
  265. if c.cipher != nil {
  266. c.cipher.Decrypt(buf, buf)
  267. }
  268. dlen := uint32(binary.BigEndian.Uint16(buf))
  269. if dlen < 2 {
  270. return nil, errors.New("length is less to 2")
  271. }
  272. if dlen >= MAX_LENGTH {
  273. // 数据包比较大,通过后面的4位长度来表示实际长度
  274. buf = make([]byte, 4)
  275. _, err := io.ReadFull(c.conn, buf)
  276. if err != nil {
  277. return nil, err
  278. }
  279. if c.cipher != nil {
  280. c.cipher.Decrypt(buf, buf)
  281. }
  282. dlen = binary.BigEndian.Uint32(buf)
  283. if dlen < MAX_LENGTH || dlen > MAX2_LENGTH {
  284. return nil, errors.New("wrong length in read message")
  285. }
  286. }
  287. // 读取指定长度的数据
  288. buf = make([]byte, dlen+1) // 最后一个是crc的值
  289. _, err = io.ReadFull(c.conn, buf)
  290. if err != nil {
  291. return nil, err
  292. }
  293. if c.cipher != nil {
  294. c.cipher.Decrypt(buf, buf)
  295. }
  296. // 检查CRC8
  297. if util.CRC8(buf[:dlen]) != buf[dlen] {
  298. return nil, errors.New("CRC error")
  299. }
  300. return buf[:dlen], nil
  301. }
  302. // 获取Auth信息
  303. // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
  304. func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  305. defer func() {
  306. if r := recover(); r != nil {
  307. err = fmt.Errorf("recovered from panic: %v", r)
  308. return
  309. }
  310. }()
  311. msg, err := c.readMessage(c.cf.ReadWait)
  312. if err != nil {
  313. return
  314. }
  315. msgLen := len(msg)
  316. if msgLen < 4 {
  317. err = errors.New("message length less than 4")
  318. return
  319. }
  320. start := 0
  321. id := binary.BigEndian.Uint16(msg[start : start+2])
  322. if id != config.ID_AUTH {
  323. err = fmt.Errorf("wrong message id: %d", id)
  324. return
  325. }
  326. start += 2
  327. protoLen := int(msg[start])
  328. if protoLen < 2 {
  329. err = errors.New("wrong proto length")
  330. return
  331. }
  332. start++
  333. proto = string(msg[start : start+protoLen])
  334. if proto != PROTO {
  335. err = fmt.Errorf("wrong proto: %s", proto)
  336. return
  337. }
  338. start += protoLen
  339. version = msg[start]
  340. if version != VERSION {
  341. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  342. return
  343. }
  344. start++
  345. channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
  346. if channelLen < 2 {
  347. err = errors.New("wrong channel length")
  348. return
  349. }
  350. start += 2
  351. channel = string(msg[start : start+channelLen])
  352. start += channelLen
  353. auth = msg[start:]
  354. return
  355. }
  356. // 发送请求数据包到网络
  357. func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error {
  358. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则截断
  359. cmdLen := len(cmd)
  360. if cmdLen > 0x7F {
  361. return errors.New("command length is more than 0x7F")
  362. }
  363. dlen := 2 + 1 + cmdLen + len(data)
  364. buf, start := c.writeDataLen(dlen)
  365. index := start
  366. binary.BigEndian.PutUint16(buf[index:index+2], id)
  367. index += 2
  368. buf[index] = byte(cmdLen)
  369. index++
  370. copy(buf[index:index+cmdLen], cmd)
  371. index += cmdLen
  372. copy(buf[index:], data)
  373. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  374. return c.writeMessage(buf)
  375. }
  376. // 发送响应数据包到网络
  377. // 网络格式:[id, stateCode, data]
  378. func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error {
  379. dlen := 2 + 1 + len(data)
  380. buf, start := c.writeDataLen(dlen)
  381. index := start
  382. binary.BigEndian.PutUint16(buf[index:index+2], id)
  383. index += 2
  384. buf[index] = state | 0x80
  385. index++
  386. copy(buf[index:], data)
  387. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  388. return c.writeMessage(buf)
  389. }
  390. // 发送ping包
  391. func (c *Tcp2) WritePing(id uint16) error {
  392. dlen := 2
  393. buf, start := c.writeDataLen(dlen)
  394. index := start
  395. binary.BigEndian.PutUint16(buf[index:index+2], id)
  396. // index += 2
  397. buf[start+dlen] = util.CRC8(buf[start : start+dlen])
  398. return c.writeMessage(buf)
  399. }
  400. // 获取信息
  401. func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  402. msg, err := c.readMessage(deadline)
  403. if err != nil {
  404. return
  405. }
  406. msgLen := len(msg)
  407. id = binary.BigEndian.Uint16(msg[0:2])
  408. // ping信息
  409. if msgLen == 2 {
  410. msgType = conn.PingMsg
  411. return
  412. }
  413. if id > config.ID_MAX {
  414. err = fmt.Errorf("wrong message id: %d", id)
  415. return
  416. }
  417. cmdx := msg[2]
  418. if (cmdx & 0x80) == 0 {
  419. // 请求包
  420. msgType = conn.RequestMsg
  421. cmdLen := int(cmdx)
  422. cmd = string(msg[3 : cmdLen+3])
  423. data = msg[cmdLen+3:]
  424. return
  425. } else {
  426. // 响应数据包
  427. msgType = conn.ResponseMsg
  428. state = cmdx & 0x7F
  429. data = msg[3:]
  430. return
  431. }
  432. }
  433. // 获取远程的地址
  434. func (c *Tcp2) RemoteAddr() net.Addr {
  435. return c.conn.RemoteAddr()
  436. }
  437. // 获取本地的地址
  438. func (c *Tcp2) LocalAddr() net.Addr {
  439. return c.conn.LocalAddr()
  440. }
  441. func (c *Tcp2) Close() error {
  442. return c.conn.Close()
  443. }