Bläddra i källkod

add gzip in connect

Joyit 3 veckor sedan
förälder
incheckning
e9086c71c8
7 ändrade filer med 291 tillägg och 112 borttagningar
  1. 2 1
      README.md
  2. 9 6
      conn/README.md
  3. 113 49
      conn/tcp2/tcp2.go
  4. 51 0
      conn/util/compress.go
  5. 114 54
      conn/ws2/ws2.go
  6. 1 1
      examples/client-tcp2.go
  7. 1 1
      examples/client-ws2.go

+ 2 - 1
README.md

@@ -23,8 +23,9 @@
 
 - 建立内存池来分配内存,减少内存碎片
 - 同地址多连接共存,使用不同的连接发送消息,减少延时,提高消息送达可靠性
-- 转发地址定时测试切换回到主服务节点
 
 ## 已经解决的问题
 
+- 增加 gzip 的功能,只需要压缩数据部分
+- 转发地址定时测试切换回到主服务节点
 - 增加订阅中间件,处理验证登录等问题

+ 9 - 6
conn/README.md

@@ -11,10 +11,11 @@ id 号大于 65500 为内部使用。
 
 建立连接后发送第一个数据包,包括下面的内容:
 
-id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
+id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
 
 这个数据包必须是除混淆包(如果有的话)之后的第一个包,如果解析不成功则直接断开连接。
 proto 的字符长度不能超过 255
+option 选项,目前只使用最低位,如果为 1 表示启用数据压缩。是否启用数据压缩由客户端决定,如果客户端不支持数据压缩,那就无法启用。
 channel 的字符长度不能超过 65535
 auth 对应的是应用层的认证,如果没有的话可以为空,具体的认证方式由应用层决定。
 auth 的数据结构由应用层自己决定。
@@ -25,14 +26,16 @@ id(uint16)
 
 ### 请求包
 
-id(uint16)+cmdLen(uint8)+cmd(string)+data([]byte)
+id(uint16)+cmdLen(uint8)+cmd(string)+option(byte)+data([]byte)
+
+option 选项,如果不支持压缩的功能,则选项字节不需要。
+目前只使用最低位,如果为 1 表示当前数据压缩。
 
 ### 响应包
 
-id(uint16)+state(uint8)+data([]byte)
+id(uint16)+state(uint8)+option(byte)+data([]byte)
 
 其中 state 字节的最高位 1,需要 &0x7F 处理。
 
-### 频道
-
-id(65502)+channel(string)
+option 选项,如果不支持压缩的功能,则选项字节不需要。
+目前只使用最低位,如果为 1 表示当前数据压缩。

+ 113 - 49
conn/tcp2/tcp2.go

@@ -24,9 +24,10 @@ const MAX_LENGTH = 0xFFFF
 const MAX2_LENGTH = 0x1FFFFFFF // 500 M,避免申请过大内存
 
 type Tcp2 struct {
-	cf     *config.Config
-	conn   net.Conn
-	cipher *util.Cipher // 记录当前的加解密类
+	cf       *config.Config
+	conn     net.Conn
+	cipher   *util.Cipher // 记录当前的加解密类
+	compress bool         // 是否启用压缩
 }
 
 // 服务端
@@ -186,9 +187,10 @@ func Dial(cf *config.Config, addr string, hash string) (conn.Connect, error) {
 	}
 	// 初始化
 	c := &Tcp2{
-		cf:     cf,
-		conn:   conn,
-		cipher: cipher,
+		cf:       cf,
+		conn:     conn,
+		cipher:   cipher,
+		compress: true,
 	}
 	return c, nil
 }
@@ -244,8 +246,8 @@ func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
 	if channelLen > 0xFFFF {
 		return errors.New("length of channel over")
 	}
-	// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
-	dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
+	// id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
+	dlen := 2 + 1 + protoLen + 1 + 1 + 2 + channelLen + len(auth)
 
 	buf, start := c.writeDataLen(dlen)
 	index := start
@@ -260,6 +262,13 @@ func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
 	buf[index] = VERSION
 	index++
 
+	if c.compress {
+		buf[index] = 0x01
+	} else {
+		buf[index] = 0
+	}
+	index++
+
 	binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
 	index += 2
 
@@ -322,7 +331,7 @@ func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
 }
 
 // 获取Auth信息
-// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
+// id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
 func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
 	defer func() {
 		if r := recover(); r != nil {
@@ -335,48 +344,51 @@ func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 		return
 	}
 	msgLen := len(msg)
-	if msgLen < 4 {
-		err = errors.New("message length less than 4")
+	if msgLen < 9 {
+		err = errors.New("wrong message length")
 		return
 	}
-	start := 0
-	id := binary.BigEndian.Uint16(msg[start : start+2])
+	index := 0
+	id := binary.BigEndian.Uint16(msg[index : index+2])
 	if id != config.ID_AUTH {
 		err = fmt.Errorf("wrong message id: %d", id)
 		return
 	}
-	start += 2
+	index += 2
 
-	protoLen := int(msg[start])
+	protoLen := int(msg[index])
 	if protoLen < 2 {
 		err = errors.New("wrong proto length")
 		return
 	}
-	start++
-	proto = string(msg[start : start+protoLen])
+	index++
+	proto = string(msg[index : index+protoLen])
 	if proto != PROTO {
 		err = fmt.Errorf("wrong proto: %s", proto)
 		return
 	}
-	start += protoLen
+	index += protoLen
 
-	version = msg[start]
+	version = msg[index]
 	if version != VERSION {
 		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
 		return
 	}
-	start++
+	index++
+
+	c.compress = (msg[index] & 0x01) != 0
+	index++
 
-	channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
+	channelLen := int(binary.BigEndian.Uint16(msg[index : index+2]))
 	if channelLen < 2 {
 		err = errors.New("wrong channel length")
 		return
 	}
-	start += 2
-	channel = string(msg[start : start+channelLen])
-	start += channelLen
+	index += 2
+	channel = string(msg[index : index+channelLen])
+	index += channelLen
 
-	auth = msg[start:]
+	auth = msg[index:]
 	return
 }
 
@@ -387,33 +399,83 @@ func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error {
 	if cmdLen > 0x7F {
 		return errors.New("command length is more than 0x7F")
 	}
-	dlen := 2 + 1 + cmdLen + len(data)
-	buf, start := c.writeDataLen(dlen)
-	index := start
-	binary.BigEndian.PutUint16(buf[index:index+2], id)
-	index += 2
-	buf[index] = byte(cmdLen)
-	index++
-	copy(buf[index:index+cmdLen], cmd)
-	index += cmdLen
-	copy(buf[index:], data)
-	buf[start+dlen] = util.CRC8(buf[start : start+dlen])
-	return c.writeMessage(buf)
+	dlen := len(data)
+	if c.compress && dlen > 0 {
+		compressedData, ok := util.CompressData(data)
+		ddlen := 2 + 1 + cmdLen + 1 + len(compressedData)
+		buf, start := c.writeDataLen(ddlen)
+		index := start
+		binary.BigEndian.PutUint16(buf[index:index+2], id)
+		index += 2
+
+		buf[index] = byte(cmdLen)
+		index++
+		copy(buf[index:index+cmdLen], cmd)
+		index += cmdLen
+
+		if ok {
+			buf[index] = 0x01
+		} else {
+			buf[index] = 0
+		}
+		index++
+
+		copy(buf[index:], compressedData)
+		buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
+		return c.writeMessage(buf)
+	} else {
+		ddlen := 2 + 1 + cmdLen + dlen
+		buf, start := c.writeDataLen(ddlen)
+		index := start
+		binary.BigEndian.PutUint16(buf[index:index+2], id)
+		index += 2
+		buf[index] = byte(cmdLen)
+		index++
+		copy(buf[index:index+cmdLen], cmd)
+		index += cmdLen
+		copy(buf[index:], data)
+		buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
+		return c.writeMessage(buf)
+	}
 }
 
 // 发送响应数据包到网络
 // 网络格式:[id, stateCode, data]
 func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error {
-	dlen := 2 + 1 + len(data)
-	buf, start := c.writeDataLen(dlen)
-	index := start
-	binary.BigEndian.PutUint16(buf[index:index+2], id)
-	index += 2
-	buf[index] = state | 0x80
-	index++
-	copy(buf[index:], data)
-	buf[start+dlen] = util.CRC8(buf[start : start+dlen])
-	return c.writeMessage(buf)
+	dlen := len(data)
+	if c.compress && dlen > 0 {
+		compressedData, ok := util.CompressData(data)
+		ddlen := 2 + 1 + 1 + len(compressedData)
+		buf, start := c.writeDataLen(ddlen)
+		index := start
+		binary.BigEndian.PutUint16(buf[index:index+2], id)
+		index += 2
+
+		buf[index] = state | 0x80
+		index++
+
+		if ok {
+			buf[index] = 0x001
+		} else {
+			buf[index] = 0
+		}
+		index++
+
+		copy(buf[index:], compressedData)
+		buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
+		return c.writeMessage(buf)
+	} else {
+		ddlen := 2 + 1 + len(data)
+		buf, start := c.writeDataLen(ddlen)
+		index := start
+		binary.BigEndian.PutUint16(buf[index:index+2], id)
+		index += 2
+		buf[index] = state | 0x80
+		index++
+		copy(buf[index:], data)
+		buf[start+ddlen] = util.CRC8(buf[start : start+ddlen])
+		return c.writeMessage(buf)
+	}
 }
 
 // 发送ping包
@@ -453,14 +515,16 @@ func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd s
 		cmdLen := int(cmdx)
 		cmd = string(msg[3 : cmdLen+3])
 		data = msg[cmdLen+3:]
-		return
 	} else {
 		// 响应数据包
 		msgType = conn.ResponseMsg
 		state = cmdx & 0x7F
 		data = msg[3:]
-		return
 	}
+	if c.compress && len(data) > 1 {
+		data, _ = util.UncompressData(data)
+	}
+	return
 }
 
 // 获取远程的地址

+ 51 - 0
conn/util/compress.go

@@ -0,0 +1,51 @@
+package util
+
+import (
+	"bytes"
+	"compress/gzip"
+	"io"
+	"log"
+)
+
+// 压缩数据
+// 注意:如果没有启用压缩,和数据为空不要调用该函数
+func CompressData(data []byte) (out []byte, ok bool) {
+	var buf bytes.Buffer
+	zw := gzip.NewWriter(&buf)
+	if _, err := zw.Write(data); err != nil {
+		log.Println("[CompressData Write ERROR]", err)
+		return data, false
+	}
+	if err := zw.Close(); err != nil {
+		log.Println("[CompressData Close ERROR]", err)
+		return data, false
+	}
+	compressedData := buf.Bytes()
+	if len(compressedData) >= len(data) {
+		// log.Println("user origin data:", string(data))
+		return data, false
+	} else {
+		return compressedData, true
+	}
+}
+
+// 解压缩
+func UncompressData(compressedData []byte) (data []byte, ok bool) {
+	// 没有压缩,直接返回
+	if compressedData[0]&0x01 != 0 {
+		zr, err := gzip.NewReader(bytes.NewReader(compressedData[1:]))
+		if err != nil {
+			log.Println("[UncompressData NewReader ERROR]", err)
+			return compressedData, false
+		}
+		defer zr.Close()
+		uncompressedData, err := io.ReadAll(zr)
+		if err != nil {
+			log.Println("[UncompressData ReadAll ERROR]", err)
+			return compressedData, false
+		}
+		return uncompressedData, true
+	} else {
+		return compressedData[1:], true
+	}
+}

+ 114 - 54
conn/ws2/ws2.go

@@ -23,9 +23,10 @@ const PROTO_STL string = "wss"
 const VERSION uint8 = 2
 
 type Ws2 struct {
-	cf     *config.Config
-	conn   *websocket.Conn
-	cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
+	cf       *config.Config
+	conn     *websocket.Conn
+	cipher   *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
+	compress bool         // 是否启用压缩
 }
 
 var upgrader = websocket.Upgrader{} // use default options
@@ -192,9 +193,10 @@ func Dial(cf *config.Config, scheme string, addr string, path string, hash strin
 		return nil, err
 	}
 	ws := &Ws2{
-		cf:     cf,
-		conn:   conn,
-		cipher: cipher,
+		cf:       cf,
+		conn:     conn,
+		cipher:   cipher,
+		compress: true,
 	}
 	return ws, nil
 }
@@ -220,28 +222,40 @@ func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
 	if channelLen > 0xFFFF {
 		return errors.New("length of channel over")
 	}
-	// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
-	dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
-	start := 0
+	// id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
+	dlen := 2 + 1 + protoLen + 1 + 1 + 2 + channelLen + len(auth)
+	index := 0
+
 	buf := make([]byte, dlen)
-	binary.BigEndian.PutUint16(buf[start:start+2], config.ID_AUTH)
-	start += 2
-	buf[start] = byte(protoLen)
-	start++
-	copy(buf[start:start+protoLen], []byte(PROTO))
-	start += protoLen
-	buf[start] = VERSION
-	start++
-	binary.BigEndian.PutUint16(buf[start:start+2], uint16(channelLen))
-	start += 2
-	copy(buf[start:start+channelLen], []byte(channel))
-	start += channelLen
-	copy(buf[start:], auth)
+	binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
+	index += 2
+
+	buf[index] = byte(protoLen)
+	index++
+	copy(buf[index:index+protoLen], []byte(PROTO))
+	index += protoLen
+
+	buf[index] = VERSION
+	index++
+
+	if c.compress {
+		buf[index] = 0x01
+	} else {
+		buf[index] = 0
+	}
+	index++
+
+	binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
+	index += 2
+	copy(buf[index:index+channelLen], []byte(channel))
+	index += channelLen
+
+	copy(buf[index:], auth)
 	return c.writeMessage(buf)
 }
 
 // 获取Auth信息
-// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
+// id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
 func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
 	defer func() {
 		if r := recover(); r != nil {
@@ -258,52 +272,55 @@ func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 		return
 	}
 	msgLen := len(msg)
-	if msgLen < 4 {
-		err = errors.New("message length less than 4")
+	if msgLen < 9 {
+		err = errors.New("wrong message length")
 		return
 	}
 	// 将读出来的数据进行解密
 	if c.cipher != nil {
 		c.cipher.Decrypt(msg, msg)
 	}
-	start := 0
-	id := binary.BigEndian.Uint16(msg[start : start+2])
+	index := 0
+	id := binary.BigEndian.Uint16(msg[index : index+2])
 	if id != config.ID_AUTH {
 		err = fmt.Errorf("wrong message id: %d", id)
 		return
 	}
-	start += 2
+	index += 2
 
-	protoLen := int(msg[start])
+	protoLen := int(msg[index])
 	if protoLen < 2 {
 		err = errors.New("wrong proto length")
 		return
 	}
-	start++
-	proto = string(msg[start : start+protoLen])
+	index++
+	proto = string(msg[index : index+protoLen])
 	if proto != PROTO {
 		err = fmt.Errorf("wrong proto: %s", proto)
 		return
 	}
-	start += protoLen
+	index += protoLen
 
-	version = msg[start]
+	version = msg[index]
 	if version != VERSION {
 		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
 		return
 	}
-	start++
+	index++
 
-	channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
+	c.compress = (msg[index] & 0x01) != 0
+	index++
+
+	channelLen := int(binary.BigEndian.Uint16(msg[index : index+2]))
 	if channelLen < 2 {
 		err = errors.New("wrong channel length")
 		return
 	}
-	start += 2
-	channel = string(msg[start : start+channelLen])
-	start += channelLen
+	index += 2
+	channel = string(msg[index : index+channelLen])
+	index += channelLen
 
-	auth = msg[start:]
+	auth = msg[index:]
 	return
 }
 
@@ -314,24 +331,65 @@ func (c *Ws2) WriteRequest(id uint16, cmd string, data []byte) error {
 	if cmdLen > 0x7F {
 		return errors.New("length of command more than 0x7F")
 	}
-	dlen := 2 + 1 + cmdLen + len(data)
-	buf := make([]byte, dlen) // 申请内存
-	binary.BigEndian.PutUint16(buf[0:2], id)
-	buf[2] = byte(cmdLen)
-	copy(buf[3:], cmd)
-	copy(buf[3+cmdLen:], data)
-	return c.writeMessage(buf)
+	dlen := len(data)
+	if c.compress && dlen > 0 {
+		compressedData, ok := util.CompressData(data)
+		ddlen := 2 + 1 + cmdLen + 1 + len(compressedData)
+		buf := make([]byte, ddlen) // 申请内存
+		index := 0
+		binary.BigEndian.PutUint16(buf[index:index+2], id)
+		index += 2
+
+		buf[index] = byte(cmdLen)
+		index++
+		copy(buf[index:], cmd)
+		index += cmdLen
+
+		if ok {
+			buf[index] = 0x01
+		} else {
+			buf[index] = 0
+		}
+		index++
+
+		copy(buf[index:], compressedData)
+		return c.writeMessage(buf)
+	} else {
+		ddlen := 2 + 1 + cmdLen + dlen
+		buf := make([]byte, ddlen) // 申请内存
+		binary.BigEndian.PutUint16(buf[0:2], id)
+		buf[2] = byte(cmdLen)
+		copy(buf[3:], cmd)
+		copy(buf[3+cmdLen:], data)
+		return c.writeMessage(buf)
+	}
 }
 
 // 发送响应数据包到网络
 // 网络格式:[id, stateCode, data]
 func (c *Ws2) WriteResponse(id uint16, state uint8, data []byte) error {
-	dlen := 2 + 1 + len(data)
-	buf := make([]byte, dlen)
-	binary.BigEndian.PutUint16(buf[0:2], id)
-	buf[2] = state | 0x80
-	copy(buf[3:], data)
-	return c.writeMessage(buf)
+	dlen := len(data)
+	if c.compress && dlen > 0 {
+		compressedData, ok := util.CompressData(data)
+		ddlen := 2 + 1 + 1 + len(compressedData)
+		buf := make([]byte, ddlen)
+		binary.BigEndian.PutUint16(buf[0:2], id)
+		buf[2] = state | 0x80
+		if ok {
+			buf[3] = 0x01
+		} else {
+			buf[3] = 0
+		}
+		copy(buf[4:], compressedData)
+		return c.writeMessage(buf)
+	} else {
+		ddlen := 2 + 1 + dlen
+		buf := make([]byte, ddlen)
+		binary.BigEndian.PutUint16(buf[0:2], id)
+		buf[2] = state | 0x80
+		copy(buf[3:], data)
+		return c.writeMessage(buf)
+	}
 }
 
 // 发送ping包
@@ -379,14 +437,16 @@ func (c *Ws2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd st
 		cmdLen := int(cmdx)
 		cmd = string(msg[3 : cmdLen+3])
 		data = msg[cmdLen+3:]
-		return
 	} else {
 		// 响应数据包
 		msgType = conn.ResponseMsg
 		state = cmdx & 0x7F
 		data = msg[3:]
-		return
 	}
+	if c.compress && len(data) > 1 {
+		data, _ = util.UncompressData(data)
+	}
+	return
 }
 
 // 获取远程的地址

+ 1 - 1
examples/client-tcp2.go

@@ -58,7 +58,7 @@ func main() {
 	}
 
 	// 获取信息
-	rsp := hub.GetOne(regexp.MustCompile("/tinymq/server"), "hello", []byte("hello from client"))
+	rsp := hub.GetOne(regexp.MustCompile("/tinymq/server"), "hello", []byte("hello from client, hello from client, hello from client"))
 	if rsp.State != config.STATE_OK {
 		log.Println("error state:", rsp.State)
 		return

+ 1 - 1
examples/client-ws2.go

@@ -61,7 +61,7 @@ func main() {
 	}
 
 	// 获取信息
-	rsp := hub.GetOne(regexp.MustCompile(remoteChannel), "hello", []byte("hello from client"))
+	rsp := hub.GetOne(regexp.MustCompile(remoteChannel), "hello", []byte("hello from client,hello from client,hello from client"))
 	if rsp.State != config.STATE_OK {
 		log.Println("error state:", rsp.State)
 		return