Browse Source

change protocol and version position

Joyit 3 tuần trước cách đây
mục cha
commit
199d80ede0
3 tập tin đã thay đổi với 41 bổ sung21 xóa
  1. 1 1
      conn/README.md
  2. 23 10
      conn/tcp2/tcp2.go
  3. 17 10
      conn/ws2/ws2.go

+ 1 - 1
conn/README.md

@@ -11,7 +11,7 @@ id 号大于 65500 为内部使用。
 
 建立连接后发送第一个数据包,包括下面的内容:
 
-id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
+id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
 
 这个数据包必须是除混淆包(如果有的话)之后的第一个包,如果解析不成功则直接断开连接。
 proto 的字符长度不能超过 255

+ 23 - 10
conn/tcp2/tcp2.go

@@ -236,23 +236,33 @@ func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) {
 // 建立连接后第一个发送的消息
 func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
 	protoLen := len(PROTO)
+	if protoLen > 0xFF {
+		return errors.New("length of protocol over")
+	}
+
 	channelLen := len(channel)
 	if channelLen > 0xFFFF {
 		return errors.New("length of channel over")
 	}
-	dlen := 2 + 1 + 1 + protoLen + 2 + channelLen + len(auth)
+	// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
+	dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
+
 	buf, start := c.writeDataLen(dlen)
 	index := start
 	binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
 	index += 2
-	buf[index] = VERSION
-	index++
+
 	buf[index] = byte(protoLen)
 	index++
 	copy(buf[index:index+protoLen], []byte(PROTO))
 	index += protoLen
+
+	buf[index] = VERSION
+	index++
+
 	binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
 	index += 2
+
 	copy(buf[index:index+channelLen], []byte(channel))
 	index += channelLen
 	copy(buf[index:], auth)
@@ -312,7 +322,7 @@ func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
 }
 
 // 获取Auth信息
-// id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
+// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
 func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
 	defer func() {
 		if r := recover(); r != nil {
@@ -337,12 +347,6 @@ func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 	}
 	start += 2
 
-	version = msg[start]
-	if version != VERSION {
-		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
-		return
-	}
-	start++
 	protoLen := int(msg[start])
 	if protoLen < 2 {
 		err = errors.New("wrong proto length")
@@ -355,6 +359,14 @@ func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 		return
 	}
 	start += protoLen
+
+	version = msg[start]
+	if version != VERSION {
+		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
+		return
+	}
+	start++
+
 	channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
 	if channelLen < 2 {
 		err = errors.New("wrong channel length")
@@ -363,6 +375,7 @@ func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 	start += 2
 	channel = string(msg[start : start+channelLen])
 	start += channelLen
+
 	auth = msg[start:]
 	return
 }

+ 17 - 10
conn/ws2/ws2.go

@@ -213,21 +213,25 @@ func (c *Ws2) writeMessage(buf []byte) (err error) {
 // 建立连接后第一个发送的消息
 func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
 	protoLen := len(PROTO)
+	if protoLen > 0xFF {
+		return errors.New("length of protocol over")
+	}
 	channelLen := len(channel)
 	if channelLen > 0xFFFF {
 		return errors.New("length of channel over")
 	}
-	dlen := 2 + 1 + 1 + protoLen + 2 + channelLen + len(auth)
+	// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
+	dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
 	start := 0
 	buf := make([]byte, dlen)
 	binary.BigEndian.PutUint16(buf[start:start+2], config.ID_AUTH)
 	start += 2
-	buf[start] = VERSION
-	start++
 	buf[start] = byte(protoLen)
 	start++
 	copy(buf[start:start+protoLen], []byte(PROTO))
 	start += protoLen
+	buf[start] = VERSION
+	start++
 	binary.BigEndian.PutUint16(buf[start:start+2], uint16(channelLen))
 	start += 2
 	copy(buf[start:start+channelLen], []byte(channel))
@@ -237,7 +241,7 @@ func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
 }
 
 // 获取Auth信息
-// id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
+// id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
 func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
 	defer func() {
 		if r := recover(); r != nil {
@@ -270,12 +274,6 @@ func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 	}
 	start += 2
 
-	version = msg[start]
-	if version != VERSION {
-		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
-		return
-	}
-	start++
 	protoLen := int(msg[start])
 	if protoLen < 2 {
 		err = errors.New("wrong proto length")
@@ -288,6 +286,14 @@ func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 		return
 	}
 	start += protoLen
+
+	version = msg[start]
+	if version != VERSION {
+		err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
+		return
+	}
+	start++
+
 	channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
 	if channelLen < 2 {
 		err = errors.New("wrong channel length")
@@ -296,6 +302,7 @@ func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth
 	start += 2
 	channel = string(msg[start : start+channelLen])
 	start += channelLen
+
 	auth = msg[start:]
 	return
 }