|
@@ -1,4 +1,4 @@
|
|
|
-package tpv2
|
|
|
+package tcp2
|
|
|
|
|
|
import (
|
|
|
"crypto/rand"
|
|
@@ -16,14 +16,14 @@ import (
|
|
|
"git.me9.top/git/tinymq/conn/util"
|
|
|
)
|
|
|
|
|
|
+const PROTO string = "tcp"
|
|
|
const VERSION uint8 = 2
|
|
|
-const PROTO string = "tp"
|
|
|
|
|
|
// 数据包的最大长度
|
|
|
const MAX_LENGTH = 0xFFFF
|
|
|
const MAX2_LENGTH = 0x1FFFFFFF // 500 M,避免申请过大内存
|
|
|
|
|
|
-type TpConnectV2 struct {
|
|
|
+type Tcp2 struct {
|
|
|
cf *config.Config
|
|
|
conn net.Conn
|
|
|
cipher *util.Cipher // 记录当前的加解密类
|
|
@@ -53,7 +53,7 @@ func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFu
|
|
|
log.Printf("Listening and serving tcp on %s\n", bind)
|
|
|
l, err := net.Listen("tcp", bind)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Server ERROR]", err)
|
|
|
+ log.Println("[tcp2 Server ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
go func(l net.Listener) {
|
|
@@ -66,7 +66,7 @@ func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFu
|
|
|
}
|
|
|
go func(conn net.Conn) {
|
|
|
if ci == nil {
|
|
|
- c := &TpConnectV2{
|
|
|
+ c := &Tcp2{
|
|
|
cf: cf,
|
|
|
conn: conn,
|
|
|
}
|
|
@@ -80,36 +80,36 @@ func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFu
|
|
|
eiv = make([]byte, ci.IvLen)
|
|
|
_, err = rand.Read(eiv)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Server rand.Read ERROR]", err)
|
|
|
+ log.Println("[tcp2 Server rand.Read ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
// 发送 IV
|
|
|
conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
|
|
|
if _, err := conn.Write(eiv); err != nil {
|
|
|
- log.Println("[tpv2 Server conn.Write ERROR]", err)
|
|
|
+ log.Println("[tcp2 Server conn.Write ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 读取 IV
|
|
|
err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Server SetReadDeadline ERROR]", err)
|
|
|
+ log.Println("[tcp2 Server SetReadDeadline ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
div = make([]byte, ci.IvLen)
|
|
|
_, err := io.ReadFull(conn, div)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Server ReadFull ERROR]", err)
|
|
|
+ log.Println("[tcp2 Server ReadFull ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 NewCipher ERROR]", err)
|
|
|
+ log.Println("[tcp2 NewCipher ERROR]", err)
|
|
|
return
|
|
|
}
|
|
|
// 初始化
|
|
|
- c := &TpConnectV2{
|
|
|
+ c := &Tcp2{
|
|
|
cf: cf,
|
|
|
conn: conn,
|
|
|
cipher: cipher,
|
|
@@ -129,7 +129,7 @@ func Client(cf *config.Config, addr string, hash string) (conn.Connect, error) {
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- c := &TpConnectV2{
|
|
|
+ c := &Tcp2{
|
|
|
cf: cf,
|
|
|
conn: conn,
|
|
|
}
|
|
@@ -156,36 +156,36 @@ func Client(cf *config.Config, addr string, hash string) (conn.Connect, error) {
|
|
|
eiv = make([]byte, ci.IvLen)
|
|
|
_, err = rand.Read(eiv)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Client rand.Read ERROR]", err)
|
|
|
+ log.Println("[tcp2 Client rand.Read ERROR]", err)
|
|
|
return nil, err
|
|
|
}
|
|
|
// 发送 IV
|
|
|
conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
|
|
|
if _, err := conn.Write(eiv); err != nil {
|
|
|
- log.Println("[tpv2 Client conn.Write ERROR]", err)
|
|
|
+ log.Println("[tcp2 Client conn.Write ERROR]", err)
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 读取 IV
|
|
|
err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Client SetReadDeadline ERROR]", err)
|
|
|
+ log.Println("[tcp2 Client SetReadDeadline ERROR]", err)
|
|
|
return nil, err
|
|
|
}
|
|
|
div = make([]byte, ci.IvLen)
|
|
|
_, err := io.ReadFull(conn, div)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 Client ReadFull ERROR]", err)
|
|
|
+ log.Println("[tcp2 Client ReadFull ERROR]", err)
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
|
|
|
if err != nil {
|
|
|
- log.Println("[tpv2 NewCipher ERROR]", err)
|
|
|
+ log.Println("[tcp2 NewCipher ERROR]", err)
|
|
|
return nil, err
|
|
|
}
|
|
|
// 初始化
|
|
|
- c := &TpConnectV2{
|
|
|
+ c := &Tcp2{
|
|
|
cf: cf,
|
|
|
conn: conn,
|
|
|
cipher: cipher,
|
|
@@ -195,7 +195,7 @@ func Client(cf *config.Config, addr string, hash string) (conn.Connect, error) {
|
|
|
|
|
|
// 发送数据到网络
|
|
|
// 如果有加密函数的话会直接修改源数据
|
|
|
-func (c *TpConnectV2) writeMessage(buf []byte) (err error) {
|
|
|
+func (c *Tcp2) writeMessage(buf []byte) (err error) {
|
|
|
if len(buf) > MAX2_LENGTH {
|
|
|
return fmt.Errorf("data length more than %d", MAX2_LENGTH)
|
|
|
}
|
|
@@ -218,7 +218,7 @@ func (c *TpConnectV2) writeMessage(buf []byte) (err error) {
|
|
|
|
|
|
// 申请内存并写入数据长度信息
|
|
|
// 还多申请一个字节用于保存crc
|
|
|
-func (c *TpConnectV2) writeDataLen(dlen int) (buf []byte, start int) {
|
|
|
+func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) {
|
|
|
if dlen >= MAX_LENGTH {
|
|
|
buf = make([]byte, dlen+2+4+1)
|
|
|
start = 2 + 4
|
|
@@ -234,7 +234,7 @@ func (c *TpConnectV2) writeDataLen(dlen int) (buf []byte, start int) {
|
|
|
|
|
|
// 发送Auth信息
|
|
|
// 建立连接后第一个发送的消息
|
|
|
-func (c *TpConnectV2) WriteAuthInfo(channel string, auth []byte) (err error) {
|
|
|
+func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
|
|
|
protoLen := len(PROTO)
|
|
|
channelLen := len(channel)
|
|
|
if channelLen > 0xFFFF {
|
|
@@ -261,7 +261,7 @@ func (c *TpConnectV2) WriteAuthInfo(channel string, auth []byte) (err error) {
|
|
|
}
|
|
|
|
|
|
// 从连接中读取信息
|
|
|
-func (c *TpConnectV2) readMessage(deadline int) ([]byte, error) {
|
|
|
+func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
|
|
|
buf := make([]byte, 2)
|
|
|
err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
|
|
|
if err != nil {
|
|
@@ -313,7 +313,7 @@ func (c *TpConnectV2) readMessage(deadline int) ([]byte, error) {
|
|
|
|
|
|
// 获取Auth信息
|
|
|
// id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
|
|
|
-func (c *TpConnectV2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
|
|
|
+func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
|
|
|
defer func() {
|
|
|
if r := recover(); r != nil {
|
|
|
err = fmt.Errorf("recovered from panic: %v", r)
|
|
@@ -368,7 +368,7 @@ func (c *TpConnectV2) ReadAuthInfo() (proto string, version uint8, channel strin
|
|
|
}
|
|
|
|
|
|
// 发送请求数据包到网络
|
|
|
-func (c *TpConnectV2) WriteRequest(id uint16, cmd string, data []byte) error {
|
|
|
+func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error {
|
|
|
// 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则截断
|
|
|
cmdLen := len(cmd)
|
|
|
if cmdLen > 0x7F {
|
|
@@ -390,7 +390,7 @@ func (c *TpConnectV2) WriteRequest(id uint16, cmd string, data []byte) error {
|
|
|
|
|
|
// 发送响应数据包到网络
|
|
|
// 网络格式:[id, stateCode, data]
|
|
|
-func (c *TpConnectV2) WriteResponse(id uint16, state uint8, data []byte) error {
|
|
|
+func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error {
|
|
|
dlen := 2 + 1 + len(data)
|
|
|
buf, start := c.writeDataLen(dlen)
|
|
|
index := start
|
|
@@ -404,7 +404,7 @@ func (c *TpConnectV2) WriteResponse(id uint16, state uint8, data []byte) error {
|
|
|
}
|
|
|
|
|
|
// 发送ping包
|
|
|
-func (c *TpConnectV2) WritePing(id uint16) error {
|
|
|
+func (c *Tcp2) WritePing(id uint16) error {
|
|
|
dlen := 2
|
|
|
buf, start := c.writeDataLen(dlen)
|
|
|
index := start
|
|
@@ -415,7 +415,7 @@ func (c *TpConnectV2) WritePing(id uint16) error {
|
|
|
}
|
|
|
|
|
|
// 获取信息
|
|
|
-func (c *TpConnectV2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
|
|
|
+func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
|
|
|
msg, err := c.readMessage(deadline)
|
|
|
if err != nil {
|
|
|
return
|
|
@@ -451,15 +451,15 @@ func (c *TpConnectV2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16
|
|
|
}
|
|
|
|
|
|
// 获取远程的地址
|
|
|
-func (c *TpConnectV2) RemoteAddr() net.Addr {
|
|
|
+func (c *Tcp2) RemoteAddr() net.Addr {
|
|
|
return c.conn.RemoteAddr()
|
|
|
}
|
|
|
|
|
|
// 获取本地的地址
|
|
|
-func (c *TpConnectV2) LocalAddr() net.Addr {
|
|
|
+func (c *Tcp2) LocalAddr() net.Addr {
|
|
|
return c.conn.LocalAddr()
|
|
|
}
|
|
|
|
|
|
-func (c *TpConnectV2) Close() error {
|
|
|
+func (c *Tcp2) Close() error {
|
|
|
return c.conn.Close()
|
|
|
}
|