123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465 |
- package tcp2
- import (
- "crypto/rand"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "strings"
- "time"
- "git.me9.top/git/tinymq/config"
- "git.me9.top/git/tinymq/conn"
- "git.me9.top/git/tinymq/conn/util"
- )
// Wire-level constants of the tcp2 transport.
const (
	// PROTO is the protocol name carried in the auth message.
	PROTO string = "tcp"

	// VERSION is the protocol version carried in the auth message.
	VERSION uint8 = 2

	// MAX_LENGTH is the largest value of the 2-byte length prefix. A payload
	// of this length or more is announced with the prefix set to MAX_LENGTH,
	// followed by a 4-byte length field holding the real size.
	MAX_LENGTH = 0xFFFF

	// MAX2_LENGTH is the upper bound for the 4-byte length (0x1FFFFFFF, about
	// 512 MiB); it keeps a peer from forcing an oversized allocation.
	MAX2_LENGTH = 0x1FFFFFFF
)
- type Tcp2 struct {
- cf *config.Config
- conn net.Conn
- cipher *util.Cipher // 记录当前的加解密类
- }
- // 服务端
- // hash 格式 encryptMethod:encryptKey
- func Server(cf *config.Config, bind string, hash string, fn conn.ServerConnectFunc) (err error) {
- var ci *util.CipherInfo
- var encryptKey string
- if hash != "" {
- i := strings.Index(hash, ":")
- if i <= 0 {
- err = errors.New("hash is invalid")
- return
- }
- encryptMethod := hash[0:i]
- encryptKey = hash[i+1:]
- if c, ok := util.CipherMethod[encryptMethod]; ok {
- ci = c
- } else {
- return errors.New("Unsupported encryption method: " + encryptMethod)
- }
- }
- log.Printf("Listening and serving tcp on %s\n", bind)
- l, err := net.Listen("tcp", bind)
- if err != nil {
- log.Println("[tcp2 Server ERROR]", err)
- return
- }
- go func(l net.Listener) {
- defer l.Close()
- for {
- conn, err := l.Accept()
- if err != nil {
- log.Println("[accept ERROR]", err)
- return
- }
- go func(conn net.Conn) {
- if ci == nil {
- c := &Tcp2{
- cf: cf,
- conn: conn,
- }
- fn(c)
- return
- }
- var eiv []byte
- var div []byte
- if ci.IvLen > 0 {
- // 服务端 IV
- eiv = make([]byte, ci.IvLen)
- _, err = rand.Read(eiv)
- if err != nil {
- log.Println("[tcp2 Server rand.Read ERROR]", err)
- return
- }
- // 发送 IV
- conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
- if _, err := conn.Write(eiv); err != nil {
- log.Println("[tcp2 Server conn.Write ERROR]", err)
- return
- }
- // 读取 IV
- err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
- if err != nil {
- log.Println("[tcp2 Server SetReadDeadline ERROR]", err)
- return
- }
- div = make([]byte, ci.IvLen)
- _, err := io.ReadFull(conn, div)
- if err != nil {
- log.Println("[tcp2 Server ReadFull ERROR]", err)
- return
- }
- }
- cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
- if err != nil {
- log.Println("[tcp2 NewCipher ERROR]", err)
- return
- }
- // 初始化
- c := &Tcp2{
- cf: cf,
- conn: conn,
- cipher: cipher,
- }
- fn(c)
- }(conn)
- }
- }(l)
- return
- }
- // 客户端,新建一个连接
- func Dial(cf *config.Config, addr string, hash string) (conn.Connect, error) {
- // 没有加密的情况
- if hash == "" {
- conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
- if err != nil {
- return nil, err
- }
- c := &Tcp2{
- cf: cf,
- conn: conn,
- }
- return c, nil
- }
- i := strings.Index(hash, ":")
- if i <= 0 {
- return nil, errors.New("hash is invalid")
- }
- encryptMethod := hash[0:i]
- encryptKey := hash[i+1:]
- ci, ok := util.CipherMethod[encryptMethod]
- if !ok {
- return nil, errors.New("Unsupported encryption method: " + encryptMethod)
- }
- conn, err := net.DialTimeout("tcp", addr, time.Duration(cf.ConnectTimeout)*time.Millisecond)
- if err != nil {
- return nil, err
- }
- var eiv []byte
- var div []byte
- if ci.IvLen > 0 {
- // 客户端 IV
- eiv = make([]byte, ci.IvLen)
- _, err = rand.Read(eiv)
- if err != nil {
- log.Println("[tcp2 Client rand.Read ERROR]", err)
- return nil, err
- }
- // 发送 IV
- conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
- if _, err := conn.Write(eiv); err != nil {
- log.Println("[tcp2 Client conn.Write ERROR]", err)
- return nil, err
- }
- // 读取 IV
- err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
- if err != nil {
- log.Println("[tcp2 Client SetReadDeadline ERROR]", err)
- return nil, err
- }
- div = make([]byte, ci.IvLen)
- _, err := io.ReadFull(conn, div)
- if err != nil {
- log.Println("[tcp2 Client ReadFull ERROR]", err)
- return nil, err
- }
- }
- cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
- if err != nil {
- log.Println("[tcp2 NewCipher ERROR]", err)
- return nil, err
- }
- // 初始化
- c := &Tcp2{
- cf: cf,
- conn: conn,
- cipher: cipher,
- }
- return c, nil
- }
- // 发送数据到网络
- // 如果有加密函数的话会直接修改源数据
- func (c *Tcp2) writeMessage(buf []byte) (err error) {
- if len(buf) > MAX2_LENGTH {
- return fmt.Errorf("data length more than %d", MAX2_LENGTH)
- }
- if c.cipher != nil {
- c.cipher.Encrypt(buf, buf)
- }
- c.conn.SetWriteDeadline(time.Now().Add(time.Duration(c.cf.WriteWait) * time.Millisecond))
- for {
- n, err := c.conn.Write(buf)
- if err != nil {
- return err
- }
- if n < len(buf) {
- buf = buf[n:]
- } else {
- return nil
- }
- }
- }
- // 申请内存并写入数据长度信息
- // 还多申请一个字节用于保存crc
- func (c *Tcp2) writeDataLen(dlen int) (buf []byte, start int) {
- if dlen >= MAX_LENGTH {
- buf = make([]byte, dlen+2+4+1)
- start = 2 + 4
- binary.BigEndian.PutUint16(buf[:2], MAX_LENGTH)
- binary.BigEndian.PutUint32(buf[2:6], uint32(dlen))
- } else {
- buf = make([]byte, dlen+2+1)
- start = 2
- binary.BigEndian.PutUint16(buf[:2], uint16(dlen))
- }
- return
- }
- // 发送Auth信息
- // 建立连接后第一个发送的消息
- func (c *Tcp2) WriteAuthInfo(channel string, auth []byte) (err error) {
- protoLen := len(PROTO)
- channelLen := len(channel)
- if channelLen > 0xFFFF {
- return errors.New("length of channel over")
- }
- dlen := 2 + 1 + 1 + protoLen + 2 + channelLen + len(auth)
- buf, start := c.writeDataLen(dlen)
- index := start
- binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
- index += 2
- buf[index] = VERSION
- index++
- buf[index] = byte(protoLen)
- index++
- copy(buf[index:index+protoLen], []byte(PROTO))
- index += protoLen
- binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
- index += 2
- copy(buf[index:index+channelLen], []byte(channel))
- index += channelLen
- copy(buf[index:], auth)
- buf[start+dlen] = util.CRC8(buf[start : start+dlen])
- return c.writeMessage(buf)
- }
- // 从连接中读取信息
- func (c *Tcp2) readMessage(deadline int) ([]byte, error) {
- buf := make([]byte, 2)
- err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
- if err != nil {
- return nil, err
- }
- // 读取数据流长度
- _, err = io.ReadFull(c.conn, buf)
- if err != nil {
- return nil, err
- }
- // 将读出来的数据进行解密
- if c.cipher != nil {
- c.cipher.Decrypt(buf, buf)
- }
- dlen := uint32(binary.BigEndian.Uint16(buf))
- if dlen < 2 {
- return nil, errors.New("length is less to 2")
- }
- if dlen >= MAX_LENGTH {
- // 数据包比较大,通过后面的4位长度来表示实际长度
- buf = make([]byte, 4)
- _, err := io.ReadFull(c.conn, buf)
- if err != nil {
- return nil, err
- }
- if c.cipher != nil {
- c.cipher.Decrypt(buf, buf)
- }
- dlen = binary.BigEndian.Uint32(buf)
- if dlen < MAX_LENGTH || dlen > MAX2_LENGTH {
- return nil, errors.New("wrong length in read message")
- }
- }
- // 读取指定长度的数据
- buf = make([]byte, dlen+1) // 最后一个是crc的值
- _, err = io.ReadFull(c.conn, buf)
- if err != nil {
- return nil, err
- }
- if c.cipher != nil {
- c.cipher.Decrypt(buf, buf)
- }
- // 检查CRC8
- if util.CRC8(buf[:dlen]) != buf[dlen] {
- return nil, errors.New("CRC error")
- }
- return buf[:dlen], nil
- }
- // 获取Auth信息
- // id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
- func (c *Tcp2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
- defer func() {
- if r := recover(); r != nil {
- err = fmt.Errorf("recovered from panic: %v", r)
- return
- }
- }()
- msg, err := c.readMessage(c.cf.ReadWait)
- if err != nil {
- return
- }
- msgLen := len(msg)
- if msgLen < 4 {
- err = errors.New("message length less than 4")
- return
- }
- start := 0
- id := binary.BigEndian.Uint16(msg[start : start+2])
- if id != config.ID_AUTH {
- err = fmt.Errorf("wrong message id: %d", id)
- return
- }
- start += 2
- version = msg[start]
- if version != VERSION {
- err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
- return
- }
- start++
- protoLen := int(msg[start])
- if protoLen < 2 {
- err = errors.New("wrong proto length")
- return
- }
- start++
- proto = string(msg[start : start+protoLen])
- if proto != PROTO {
- err = fmt.Errorf("wrong proto: %s", proto)
- return
- }
- start += protoLen
- channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
- if channelLen < 2 {
- err = errors.New("wrong channel length")
- return
- }
- start += 2
- channel = string(msg[start : start+channelLen])
- start += channelLen
- auth = msg[start:]
- return
- }
- // 发送请求数据包到网络
- func (c *Tcp2) WriteRequest(id uint16, cmd string, data []byte) error {
- // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则截断
- cmdLen := len(cmd)
- if cmdLen > 0x7F {
- return errors.New("command length is more than 0x7F")
- }
- dlen := 2 + 1 + cmdLen + len(data)
- buf, start := c.writeDataLen(dlen)
- index := start
- binary.BigEndian.PutUint16(buf[index:index+2], id)
- index += 2
- buf[index] = byte(cmdLen)
- index++
- copy(buf[index:index+cmdLen], cmd)
- index += cmdLen
- copy(buf[index:], data)
- buf[start+dlen] = util.CRC8(buf[start : start+dlen])
- return c.writeMessage(buf)
- }
- // 发送响应数据包到网络
- // 网络格式:[id, stateCode, data]
- func (c *Tcp2) WriteResponse(id uint16, state uint8, data []byte) error {
- dlen := 2 + 1 + len(data)
- buf, start := c.writeDataLen(dlen)
- index := start
- binary.BigEndian.PutUint16(buf[index:index+2], id)
- index += 2
- buf[index] = state | 0x80
- index++
- copy(buf[index:], data)
- buf[start+dlen] = util.CRC8(buf[start : start+dlen])
- return c.writeMessage(buf)
- }
- // 发送ping包
- func (c *Tcp2) WritePing(id uint16) error {
- dlen := 2
- buf, start := c.writeDataLen(dlen)
- index := start
- binary.BigEndian.PutUint16(buf[index:index+2], id)
- // index += 2
- buf[start+dlen] = util.CRC8(buf[start : start+dlen])
- return c.writeMessage(buf)
- }
- // 获取信息
- func (c *Tcp2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
- msg, err := c.readMessage(deadline)
- if err != nil {
- return
- }
- msgLen := len(msg)
- id = binary.BigEndian.Uint16(msg[0:2])
- // ping信息
- if msgLen == 2 {
- msgType = conn.PingMsg
- return
- }
- if id > config.ID_MAX {
- err = fmt.Errorf("wrong message id: %d", id)
- return
- }
- cmdx := msg[2]
- if (cmdx & 0x80) == 0 {
- // 请求包
- msgType = conn.RequestMsg
- cmdLen := int(cmdx)
- cmd = string(msg[3 : cmdLen+3])
- data = msg[cmdLen+3:]
- return
- } else {
- // 响应数据包
- msgType = conn.ResponseMsg
- state = cmdx & 0x7F
- data = msg[3:]
- return
- }
- }
- // 获取远程的地址
- func (c *Tcp2) RemoteAddr() net.Addr {
- return c.conn.RemoteAddr()
- }
- // 获取本地的地址
- func (c *Tcp2) LocalAddr() net.Addr {
- return c.conn.LocalAddr()
- }
- func (c *Tcp2) Close() error {
- return c.conn.Close()
- }
|