package tcp2 import ( "crypto/rand" "encoding/binary" "errors" "fmt" "io" "log" "net" "strings" "time" "git.me9.top/git/tinymq/config" "git.me9.top/git/tinymq/conn" "git.me9.top/git/tinymq/conn/util" ) const PROTO string = "tcp" const VERSION uint8 = 2 // 数据包的最大长度 const MAX_LENGTH = 0xFFFF const MAX2_LENGTH = 0x1FFFFFFF // 500 M,避免申请过大内存 type Tcp2 struct { cf *config.Config conn net.Conn cipher *util.Cipher // 记录当前的加解密类 } // 服务端 // hash 格式 encryptMethod:encryptKey func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFunc) (err error) { var ci *util.CipherInfo var encryptKey string if hash != "" { i := strings.Index(hash, ":") if i <= 0 { err = errors.New("hash is invalid") return } encryptMethod := hash[0:i] encryptKey = hash[i+1:] if c, ok := util.CipherMethod[encryptMethod]; ok { ci = c } else { return errors.New("Unsupported encryption method: " + encryptMethod) } } log.Printf("Listening and serving tcp on %s\n", bind) l, err := net.Listen("tcp", bind) if err != nil { log.Println("[tcp2 Server ERROR]", err) return } go func(l net.Listener) { defer l.Close() for { conn, err := l.Accept() if err != nil { log.Println("[accept ERROR]", err) return } go func(conn net.Conn) { if ci == nil { c := &Tcp2{ cf: cf, conn: conn, } fn(c) return } var eiv []byte var div []byte if ci.IvLen > 0 { // 服务端 IV eiv = make([]byte, ci.IvLen) _, err = rand.Read(eiv) if err != nil { log.Println("[tcp2 Server rand.Read ERROR]", err) return } // 发送 IV conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond)) if _, err := conn.Write(eiv); err != nil { log.Println("[tcp2 Server conn.Write ERROR]", err) return } // 读取 IV err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait))) if err != nil { log.Println("[tcp2 Server SetReadDeadline ERROR]", err) return } div = make([]byte, ci.IvLen) _, err := io.ReadFull(conn, div) if err != nil { log.Println("[tcp2 Server ReadFull ERROR]", err) return } } cipher, err := util.NewCipher(ci, encryptKey, eiv, div) if err != nil { log.Println("[tcp2 NewCipher ERROR]", err) return } // 初始化 c := &Tcp2{ cf: cf, conn: conn, cipher: cipher, } fn(c) }(conn) } }(l) return } // 客户端,新建一个连接 func Client(cf *config.Config, addr string, hash string) (conn.Connect, error) { // 没有加密的情况 if hash == "" { conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond) if err != nil { return nil, err } c := &Tcp2{ cf: cf, conn: conn, } return c, nil } i := strings.Index(hash, ":") if i <= 0 { return nil, errors.New("hash is invalid") } encryptMethod := hash[0:i] encryptKey := hash[i+1:] ci, ok := util.CipherMethod[encryptMethod] if !ok { return nil, errors.New("Unsupported encryption method: " + encryptMethod) } conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond) if err != nil { return nil, err } var eiv []byte var div []byte if ci.IvLen > 0 { // 客户端 IV eiv = make([]byte, ci.IvLen) _, err = rand.Read(eiv) if err != nil { log.Println("[tcp2 Client rand.Read ERROR]", err) return nil, err } // 发送 IV conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond)) if _, err := conn.Write(eiv); err != nil { log.Println("[tcp2 Client conn.Write ERROR]", err) return nil, err } // 读取 IV err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait))) if err != nil { log.Println("[tcp2 Client SetReadDeadline ERROR]", err) return nil, err } div = make([]byte, ci.IvLen) _, err := io.ReadFull(conn, div) if err != nil { log.Println("[tcp2 Client ReadFull ERROR]", err) return nil, err } } cipher, err := util.NewCipher(ci, encryptKey, eiv, div) if err != nil { log.Println("[tcp2 NewCipher ERROR]", err) return nil, err } // 初始化 c := &Tcp2{ cf: cf, conn: conn, cipher: cipher, } return c, nil } // 发送数据到网络 // 如果有加密函数的话会直接修改源数据 func (c *Tcp2) writeMessage(buf []byte) (err error) { if len(buf) > MAX2_LENGTH { return fmt.Errorf("data length more than %d", MAX2_LENGTH) } if c.cipher != nil { c.cipher.Encrypt(buf, buf) } c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.cf.WriteWait) * time.Millisecond)) for { n, err := c.conn.Write(buf) if err != nil { return err } if n < len(buf) { buf = buf[n:] } else { return nil } } } // 申请内存并写入数据长度信息 // 还多申请一个字节用于保存crc func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) { if dlen >= MAX_LENGTH { buf = make([]byte, dlen+2+4+1) start = 2 + 4 binary.BigEndian.PutUint16(buf[:2], MAX_LENGTH) binary.BigEndian.PutUint32(buf[2:6], uint32(dlen)) } else { buf = make([]byte, dlen+2+1) start = 2 binary.BigEndian.PutUint16(buf[:2], uint16(dlen)) } return } // 发送Auth信息 // 建立连接后第一个发送的消息 func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) { protoLen := len(PROTO) channelLen := len(channel) if channelLen > 0xFFFF { return errors.New("length of channel over") } dlen := 2 + 1 + 1 + protoLen + 2 + channelLen + len(auth) buf, start := c.writeDataLen(dlen) index := start binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH) index += 2 buf[index] = VERSION index++ buf[index] = byte(protoLen) index++ copy(buf[index:index+protoLen], []byte(PROTO)) index += protoLen binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen)) index += 2 copy(buf[index:index+channelLen], []byte(channel)) index += channelLen copy(buf[index:], auth) buf[start+dlen] = util.CRC8(buf[start : start+dlen]) return c.writeMessage(buf) } // 从连接中读取信息 func (c *Tcp2) readMessage(deadline int) ([]byte, error) { buf := make([]byte, 2) err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline))) if err != nil { return nil, err } // 读取数据流长度 _, err = io.ReadFull(c.conn, buf) if err != nil { return nil, err } // 将读出来的数据进行解密 if c.cipher != nil { c.cipher.Decrypt(buf, buf) } dlen := uint32(binary.BigEndian.Uint16(buf)) if dlen < 2 { return nil, errors.New("length is less to 2") } if dlen >= MAX_LENGTH { // 数据包比较大,通过后面的4位长度来表示实际长度 buf = make([]byte, 4) _, err := io.ReadFull(c.conn, buf) if err != nil { return nil, err } if c.cipher != nil { c.cipher.Decrypt(buf, buf) } dlen = binary.BigEndian.Uint32(buf) if dlen < MAX_LENGTH || dlen > MAX2_LENGTH { return nil, errors.New("wrong length in read message") } } // 读取指定长度的数据 buf = make([]byte, dlen+1) // 最后一个是crc的值 _, err = io.ReadFull(c.conn, buf) if err != nil { return nil, err } if c.cipher != nil { c.cipher.Decrypt(buf, buf) } // 检查CRC8 if util.CRC8(buf[:dlen]) != buf[dlen] { return nil, errors.New("CRC error") } return buf[:dlen], nil } // 获取Auth信息 // id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte) func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("recovered from panic: %v", r) return } }() msg, err := c.readMessage(c.cf.ReadWait) if err != nil { return } msgLen := len(msg) if msgLen < 4 { err = errors.New("message length less than 4") return } start := 0 id := binary.BigEndian.Uint16(msg[start : start+2]) if id != config.ID_AUTH { err = fmt.Errorf("wrong message id: %d", id) return } start += 2 version = msg[start] if version != VERSION { err = fmt.Errorf("require version %d, get version: %d", VERSION, version) return } start++ protoLen := int(msg[start]) if protoLen < 2 { err = errors.New("wrong proto length") return } start++ proto = string(msg[start : start+protoLen]) if proto != PROTO { err = fmt.Errorf("wrong proto: %s", proto) return } start += protoLen channelLen := int(binary.BigEndian.Uint16(msg[start : start+2])) if channelLen < 2 { err = errors.New("wrong channel length") return } start += 2 channel = string(msg[start : start+channelLen]) start += channelLen auth = msg[start:] return } // 发送请求数据包到网络 func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error { // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则截断 cmdLen := len(cmd) if cmdLen > 0x7F { return errors.New("command length is more than 0x7F") } dlen := 2 + 1 + cmdLen + len(data) buf, start := c.writeDataLen(dlen) index := start binary.BigEndian.PutUint16(buf[index:index+2], id) index += 2 buf[index] = byte(cmdLen) index++ copy(buf[index:index+cmdLen], cmd) index += cmdLen copy(buf[index:], data) buf[start+dlen] = util.CRC8(buf[start : start+dlen]) return c.writeMessage(buf) } // 发送响应数据包到网络 // 网络格式:[id, stateCode, data] func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error { dlen := 2 + 1 + len(data) buf, start := c.writeDataLen(dlen) index := start binary.BigEndian.PutUint16(buf[index:index+2], id) index += 2 buf[index] = state | 0x80 index++ copy(buf[index:], data) buf[start+dlen] = util.CRC8(buf[start : start+dlen]) return c.writeMessage(buf) } // 发送ping包 func (c *Tcp2) WritePing(id uint16) error { dlen := 2 buf, start := c.writeDataLen(dlen) index := start binary.BigEndian.PutUint16(buf[index:index+2], id) // index += 2 buf[start+dlen] = util.CRC8(buf[start : start+dlen]) return c.writeMessage(buf) } // 获取信息 func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) { msg, err := c.readMessage(deadline) if err != nil { return } msgLen := len(msg) id = binary.BigEndian.Uint16(msg[0:2]) // ping信息 if msgLen == 2 { msgType = conn.PingMsg return } if id > config.ID_MAX { err = fmt.Errorf("wrong message id: %d", id) return } cmdx := msg[2] if (cmdx & 0x80) == 0 { // 请求包 msgType = conn.RequestMsg cmdLen := int(cmdx) cmd = string(msg[3 : cmdLen+3]) data = msg[cmdLen+3:] return } else { // 响应数据包 msgType = conn.ResponseMsg state = cmdx & 0x7F data = msg[3:] return } } // 获取远程的地址 func (c *Tcp2) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // 获取本地的地址 func (c *Tcp2) LocalAddr() net.Addr { return c.conn.LocalAddr() } func (c *Tcp2) Close() error { return c.conn.Close() }