package ws2 import ( "crypto/rand" "encoding/binary" "errors" "fmt" "log" "net" "net/http" "net/url" "strings" "time" "git.me9.top/git/tinymq/config" "git.me9.top/git/tinymq/conn" "git.me9.top/git/tinymq/conn/util" "github.com/gorilla/websocket" ) const PROTO string = "ws" const PROTO_STL string = "wss" const VERSION uint8 = 2 type Ws2 struct { cf *config.Config conn *websocket.Conn cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全 } var upgrader = websocket.Upgrader{} // use default options // websocket 服务 // 如果有绑定参数,则进行绑定操作代码 func Server(cf *config.Config, bind string, path string, hash string, fn conn.ServerConnectFunc) (err error) { var ci *util.CipherInfo var encryptKey string if hash != "" { i := strings.Index(hash, ":") if i <= 0 { return errors.New("hash is invalid") } encryptMethod := hash[0:i] encryptKey = hash[i+1:] if c, ok := util.CipherMethod[encryptMethod]; ok { ci = c } else { return errors.New("Unsupported encryption method: " + encryptMethod) } } http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println("[ws2 Server Upgrade ERROR]", err) return } if ci == nil { ws := &Ws2{ cf: cf, conn: conn, } fn(ws) return } var eiv []byte var div []byte if ci.IvLen > 0 { // 服务端 IV eiv = make([]byte, ci.IvLen) _, err = rand.Read(eiv) if err != nil { log.Println("[ws2 Server rand.Read ERROR]", err) return } // 发送 IV conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond)) if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil { log.Println("[ws2 Server conn.Write ERROR]", err) return } // 读取 IV err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait))) if err != nil { log.Println("[ws2 Server SetReadDeadline ERROR]", err) return } _, div, err = conn.ReadMessage() if err != nil { log.Println("[ws2 Server ReadFull ERROR]", err) return } } cipher, err := util.NewCipher(ci, encryptKey, eiv, div) if err != nil { log.Println("[ws2 NewCipher ERROR]", err) return } ws := &Ws2{ cf: cf, conn: conn, cipher: cipher, } fn(ws) }) if bind != "" { go func() (err error) { defer func() { if err != nil { log.Fatal(err) } }() log.Printf("Listening and serving Websocket on %s\n", bind) // 暂时使用全局的方式,后面有需求再修改 // 而且还没有 https 方式的绑定 // 需要在前端增加其他的服务进行转换 err = http.ListenAndServe(bind, nil) return }() } return } // 客户端,新建一个连接 func Dial(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) { u := url.URL{Scheme: scheme, Host: addr, Path: path} // 没有加密的情况 if hash == "" { conn, _, err := (&websocket.Dialer{ HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)), }).Dial(u.String(), nil) if err != nil { return nil, err } ws := &Ws2{ cf: cf, conn: conn, } return ws, nil } i := strings.Index(hash, ":") if i <= 0 { return nil, errors.New("hash is invalid") } encryptMethod := hash[0:i] encryptKey := hash[i+1:] ci, ok := util.CipherMethod[encryptMethod] if !ok { return nil, errors.New("Unsupported encryption method: " + encryptMethod) } conn, _, err := (&websocket.Dialer{ HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)), }).Dial(u.String(), nil) if err != nil { return nil, err } var eiv []byte var div []byte if ci.IvLen > 0 { // 客户端 IV eiv = make([]byte, ci.IvLen) _, err = rand.Read(eiv) if err != nil { log.Println("[ws2 Client rand.Read ERROR]", err) return nil, err } // 发送 IV conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond)) if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil { log.Println("[ws2 Client conn.Write ERROR]", err) return nil, err } // 读取 IV err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait))) if err != nil { log.Println("[ws2 Client SetReadDeadline ERROR]", err) return nil, err } _, div, err = conn.ReadMessage() if err != nil { log.Println("[ws2 Client ReadFull ERROR]", err) return nil, err } } cipher, err := util.NewCipher(ci, encryptKey, eiv, div) if err != nil { log.Println("[ws2 NewCipher ERROR]", err) return nil, err } ws := &Ws2{ cf: cf, conn: conn, cipher: cipher, } return ws, nil } // 发送数据到网络 // 如果有加密函数的话会直接修改源数据 func (c *Ws2) writeMessage(buf []byte) (err error) { if c.cipher != nil { c.cipher.Encrypt(buf, buf) } c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.WriteWait))) return c.conn.WriteMessage(websocket.BinaryMessage, buf) } // 发送Auth信息 // 建立连接后第一个发送的消息 func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) { protoLen := len(PROTO) if protoLen > 0xFF { return errors.New("length of protocol over") } channelLen := len(channel) if channelLen > 0xFFFF { return errors.New("length of channel over") } // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte) dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth) start := 0 buf := make([]byte, dlen) binary.BigEndian.PutUint16(buf[start:start+2], config.ID_AUTH) start += 2 buf[start] = byte(protoLen) start++ copy(buf[start:start+protoLen], []byte(PROTO)) start += protoLen buf[start] = VERSION start++ binary.BigEndian.PutUint16(buf[start:start+2], uint16(channelLen)) start += 2 copy(buf[start:start+channelLen], []byte(channel)) start += channelLen copy(buf[start:], auth) return c.writeMessage(buf) } // 获取Auth信息 // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte) func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("recovered from panic: %v", r) return } }() err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.ReadWait))) if err != nil { return } _, msg, err := c.conn.ReadMessage() if err != nil { return } msgLen := len(msg) if msgLen < 4 { err = errors.New("message length less than 4") return } // 将读出来的数据进行解密 if c.cipher != nil { c.cipher.Decrypt(msg, msg) } start := 0 id := binary.BigEndian.Uint16(msg[start : start+2]) if id != config.ID_AUTH { err = fmt.Errorf("wrong message id: %d", id) return } start += 2 protoLen := int(msg[start]) if protoLen < 2 { err = errors.New("wrong proto length") return } start++ proto = string(msg[start : start+protoLen]) if proto != PROTO { err = fmt.Errorf("wrong proto: %s", proto) return } start += protoLen version = msg[start] if version != VERSION { err = fmt.Errorf("require version %d, get version: %d", VERSION, version) return } start++ channelLen := int(binary.BigEndian.Uint16(msg[start : start+2])) if channelLen < 2 { err = errors.New("wrong channel length") return } start += 2 channel = string(msg[start : start+channelLen]) start += channelLen auth = msg[start:] return } // 发送请求数据包到网络 func (c *Ws2) WriteRequest(id uint16, cmd string, data []byte) error { // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则报错 cmdLen := len(cmd) if cmdLen > 0x7F { return errors.New("length of command more than 0x7F") } dlen := 2 + 1 + cmdLen + len(data) buf := make([]byte, dlen) // 申请内存 binary.BigEndian.PutUint16(buf[0:2], id) buf[2] = byte(cmdLen) copy(buf[3:], cmd) copy(buf[3+cmdLen:], data) return c.writeMessage(buf) } // 发送响应数据包到网络 // 网络格式:[id, stateCode, data] func (c *Ws2) WriteResponse(id uint16, state uint8, data []byte) error { dlen := 2 + 1 + len(data) buf := make([]byte, dlen) binary.BigEndian.PutUint16(buf[0:2], id) buf[2] = state | 0x80 copy(buf[3:], data) return c.writeMessage(buf) } // 发送ping包 func (c *Ws2) WritePing(id uint16) error { buf := make([]byte, 2) binary.BigEndian.PutUint16(buf[0:2], id) return c.writeMessage(buf) } // 获取信息 func (c *Ws2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) { err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline))) if err != nil { return } _, msg, err := c.conn.ReadMessage() if err != nil { return } msgLen := len(msg) if msgLen < 2 { err = errors.New("message length less than 2") return } // 将读出来的数据进行解密 if c.cipher != nil { c.cipher.Decrypt(msg, msg) } id = binary.BigEndian.Uint16(msg[0:2]) // ping信息 if msgLen == 2 { msgType = conn.PingMsg return } if id > config.ID_MAX { err = fmt.Errorf("wrong message id: %d", id) return } cmdx := msg[2] if (cmdx & 0x80) == 0 { // 请求包 msgType = conn.RequestMsg cmdLen := int(cmdx) cmd = string(msg[3 : cmdLen+3]) data = msg[cmdLen+3:] return } else { // 响应数据包 msgType = conn.ResponseMsg state = cmdx & 0x7F data = msg[3:] return } } // 获取远程的地址 func (c *Ws2) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // 获取本地的地址 func (c *Ws2) LocalAddr() net.Addr { return c.conn.LocalAddr() } func (c *Ws2) Close() error { return c.conn.Close() }