wsv2.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package wsv2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strings"
  12. "time"
  13. "git.me9.top/git/tinymq/config"
  14. "git.me9.top/git/tinymq/conn"
  15. "git.me9.top/git/tinymq/conn/util"
  16. "github.com/gorilla/websocket"
  17. )
  18. const VERSION uint8 = 2
  19. const PROTO string = "ws"
  20. type WsConnectV2 struct {
  21. cf *config.Config
  22. conn *websocket.Conn
  23. cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
  24. }
  25. var upgrader = websocket.Upgrader{} // use default options
  26. // websocket 服务
  27. // 如果有绑定参数,则进行绑定操作代码
  28. func Server(cf *config.Config, bind string, path string, hash string, fn conn.ServerConnectFunc) (err error) {
  29. var ci *util.CipherInfo
  30. var encryptKey string
  31. if hash != "" {
  32. i := strings.Index(hash, ":")
  33. if i <= 0 {
  34. return errors.New("hash is invalid")
  35. }
  36. encryptMethod := hash[0:i]
  37. encryptKey = hash[i+1:]
  38. if c, ok := util.CipherMethod[encryptMethod]; ok {
  39. ci = c
  40. } else {
  41. return errors.New("Unsupported encryption method: " + encryptMethod)
  42. }
  43. }
  44. http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
  45. conn, err := upgrader.Upgrade(w, r, nil)
  46. if err != nil {
  47. log.Println("[wsv2 Server Upgrade ERROR]", err)
  48. return
  49. }
  50. if ci == nil {
  51. ws := &WsConnectV2{
  52. cf: cf,
  53. conn: conn,
  54. }
  55. fn(ws)
  56. return
  57. }
  58. var eiv []byte
  59. var div []byte
  60. if ci.IvLen > 0 {
  61. // 服务端 IV
  62. eiv = make([]byte, ci.IvLen)
  63. _, err = rand.Read(eiv)
  64. if err != nil {
  65. log.Println("[wsv2 Server rand.Read ERROR]", err)
  66. return
  67. }
  68. // 发送 IV
  69. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  70. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  71. log.Println("[wsv2 Server conn.Write ERROR]", err)
  72. return
  73. }
  74. // 读取 IV
  75. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  76. if err != nil {
  77. log.Println("[wsv2 Server SetReadDeadline ERROR]", err)
  78. return
  79. }
  80. _, div, err = conn.ReadMessage()
  81. if err != nil {
  82. log.Println("[wsv2 Server ReadFull ERROR]", err)
  83. return
  84. }
  85. }
  86. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  87. if err != nil {
  88. log.Println("[wsv2 NewCipher ERROR]", err)
  89. return
  90. }
  91. ws := &WsConnectV2{
  92. cf: cf,
  93. conn: conn,
  94. cipher: cipher,
  95. }
  96. fn(ws)
  97. })
  98. if bind != "" {
  99. go func() (err error) {
  100. defer func() {
  101. if err != nil {
  102. log.Fatal(err)
  103. }
  104. }()
  105. log.Printf("Listening and serving Websocket on %s\n", bind)
  106. // 暂时使用全局的方式,后面有需求再修改
  107. // 而且还没有 https 方式的绑定
  108. // 需要在前端增加其他的服务进行转换
  109. err = http.ListenAndServe(bind, nil)
  110. return
  111. }()
  112. }
  113. return
  114. }
  115. // 客户端,新建一个连接
  116. func Client(cf *config.Config, addr string, path string, hash string) (conn.Connect, error) {
  117. u := url.URL{Scheme: "ws", Host: addr, Path: path}
  118. // 没有加密的情况
  119. if hash == "" {
  120. conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
  121. if err != nil {
  122. return nil, err
  123. }
  124. ws := &WsConnectV2{
  125. cf: cf,
  126. conn: conn,
  127. }
  128. return ws, nil
  129. }
  130. i := strings.Index(hash, ":")
  131. if i <= 0 {
  132. return nil, errors.New("hash is invalid")
  133. }
  134. encryptMethod := hash[0:i]
  135. encryptKey := hash[i+1:]
  136. ci, ok := util.CipherMethod[encryptMethod]
  137. if !ok {
  138. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  139. }
  140. conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
  141. if err != nil {
  142. return nil, err
  143. }
  144. var eiv []byte
  145. var div []byte
  146. if ci.IvLen > 0 {
  147. // 客户端 IV
  148. eiv = make([]byte, ci.IvLen)
  149. _, err = rand.Read(eiv)
  150. if err != nil {
  151. log.Println("[wsv2 Client rand.Read ERROR]", err)
  152. return nil, err
  153. }
  154. // 发送 IV
  155. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  156. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  157. log.Println("[wsv2 Client conn.Write ERROR]", err)
  158. return nil, err
  159. }
  160. // 读取 IV
  161. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  162. if err != nil {
  163. log.Println("[wsv2 Client SetReadDeadline ERROR]", err)
  164. return nil, err
  165. }
  166. _, div, err = conn.ReadMessage()
  167. if err != nil {
  168. log.Println("[wsv2 Client ReadFull ERROR]", err)
  169. return nil, err
  170. }
  171. }
  172. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  173. if err != nil {
  174. log.Println("[wsv2 NewCipher ERROR]", err)
  175. return nil, err
  176. }
  177. ws := &WsConnectV2{
  178. cf: cf,
  179. conn: conn,
  180. cipher: cipher,
  181. }
  182. return ws, nil
  183. }
  184. // 发送数据到网络
  185. // 如果有加密函数的话会直接修改源数据
  186. func (c *WsConnectV2) writeMessage(buf []byte) (err error) {
  187. if c.cipher != nil {
  188. c.cipher.Encrypt(buf, buf)
  189. }
  190. c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.WriteWait)))
  191. return c.conn.WriteMessage(websocket.BinaryMessage, buf)
  192. }
  193. // 发送Auth信息
  194. // 建立连接后第一个发送的消息
  195. func (c *WsConnectV2) WriteAuthInfo(channel string, auth []byte) (err error) {
  196. protoLen := len(PROTO)
  197. channelLen := len(channel)
  198. if channelLen > 0xFFFF {
  199. return errors.New("length of channel over")
  200. }
  201. dlen := 2 + 1 + 1 + protoLen + 2 + channelLen + len(auth)
  202. start := 0
  203. buf := make([]byte, dlen)
  204. binary.BigEndian.PutUint16(buf[start:start+2], config.ID_AUTH)
  205. start += 2
  206. buf[start] = VERSION
  207. start++
  208. buf[start] = byte(protoLen)
  209. start++
  210. copy(buf[start:start+protoLen], []byte(PROTO))
  211. start += protoLen
  212. binary.BigEndian.PutUint16(buf[start:start+2], uint16(channelLen))
  213. start += 2
  214. copy(buf[start:start+channelLen], []byte(channel))
  215. start += channelLen
  216. copy(buf[start:], auth)
  217. return c.writeMessage(buf)
  218. }
  219. // 获取Auth信息
  220. // id(uint16)+version(uint8)+proto(string)+channel(string)+auth([]byte)
  221. func (c *WsConnectV2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  222. defer func() {
  223. if r := recover(); r != nil {
  224. err = fmt.Errorf("recovered from panic: %v", r)
  225. return
  226. }
  227. }()
  228. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.ReadWait)))
  229. if err != nil {
  230. return
  231. }
  232. _, msg, err := c.conn.ReadMessage()
  233. if err != nil {
  234. return
  235. }
  236. msgLen := len(msg)
  237. if msgLen < 4 {
  238. err = errors.New("message length less than 4")
  239. return
  240. }
  241. // 将读出来的数据进行解密
  242. if c.cipher != nil {
  243. c.cipher.Decrypt(msg, msg)
  244. }
  245. start := 0
  246. id := binary.BigEndian.Uint16(msg[start : start+2])
  247. if id != config.ID_AUTH {
  248. err = fmt.Errorf("wrong message id: %d", id)
  249. return
  250. }
  251. start += 2
  252. version = msg[start]
  253. if version != VERSION {
  254. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  255. return
  256. }
  257. start++
  258. protoLen := int(msg[start])
  259. if protoLen < 2 {
  260. err = errors.New("wrong proto length")
  261. return
  262. }
  263. start++
  264. proto = string(msg[start : start+protoLen])
  265. if proto != PROTO {
  266. err = fmt.Errorf("wrong proto: %s", proto)
  267. return
  268. }
  269. start += protoLen
  270. channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
  271. if channelLen < 2 {
  272. err = errors.New("wrong channel length")
  273. return
  274. }
  275. start += 2
  276. channel = string(msg[start : start+channelLen])
  277. start += channelLen
  278. auth = msg[start:]
  279. return
  280. }
  281. // 发送请求数据包到网络
  282. func (c *WsConnectV2) WriteRequest(id uint16, cmd string, data []byte) error {
  283. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则报错
  284. cmdLen := len(cmd)
  285. if cmdLen > 0x7F {
  286. return errors.New("length of command more than 0x7F")
  287. }
  288. dlen := 2 + 1 + cmdLen + len(data)
  289. buf := make([]byte, dlen) // 申请内存
  290. binary.BigEndian.PutUint16(buf[0:2], id)
  291. buf[2] = byte(cmdLen)
  292. copy(buf[3:], cmd)
  293. copy(buf[3+cmdLen:], data)
  294. return c.writeMessage(buf)
  295. }
  296. // 发送响应数据包到网络
  297. // 网络格式:[id, stateCode, data]
  298. func (c *WsConnectV2) WriteResponse(id uint16, state uint8, data []byte) error {
  299. dlen := 2 + 1 + len(data)
  300. buf := make([]byte, dlen)
  301. binary.BigEndian.PutUint16(buf[0:2], id)
  302. buf[2] = state | 0x80
  303. copy(buf[3:], data)
  304. return c.writeMessage(buf)
  305. }
  306. // 发送ping包
  307. func (c *WsConnectV2) WritePing(id uint16) error {
  308. buf := make([]byte, 2)
  309. binary.BigEndian.PutUint16(buf[0:2], id)
  310. return c.writeMessage(buf)
  311. }
  312. // 获取信息
  313. func (c *WsConnectV2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  314. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  315. if err != nil {
  316. return
  317. }
  318. _, msg, err := c.conn.ReadMessage()
  319. if err != nil {
  320. return
  321. }
  322. msgLen := len(msg)
  323. if msgLen < 2 {
  324. err = errors.New("message length less than 2")
  325. return
  326. }
  327. // 将读出来的数据进行解密
  328. if c.cipher != nil {
  329. c.cipher.Decrypt(msg, msg)
  330. }
  331. id = binary.BigEndian.Uint16(msg[0:2])
  332. // ping信息
  333. if msgLen == 2 {
  334. msgType = conn.PingMsg
  335. return
  336. }
  337. if id > config.ID_MAX {
  338. err = fmt.Errorf("wrong message id: %d", id)
  339. return
  340. }
  341. cmdx := msg[2]
  342. if (cmdx & 0x80) == 0 {
  343. // 请求包
  344. msgType = conn.RequestMsg
  345. cmdLen := int(cmdx)
  346. cmd = string(msg[3 : cmdLen+3])
  347. data = msg[cmdLen+3:]
  348. return
  349. } else {
  350. // 响应数据包
  351. msgType = conn.ResponseMsg
  352. state = cmdx & 0x7F
  353. data = msg[3:]
  354. return
  355. }
  356. }
  357. // 获取远程的地址
  358. func (c *WsConnectV2) RemoteAddr() net.Addr {
  359. return c.conn.RemoteAddr()
  360. }
  361. // 获取本地的地址
  362. func (c *WsConnectV2) LocalAddr() net.Addr {
  363. return c.conn.LocalAddr()
  364. }
  365. func (c *WsConnectV2) Close() error {
  366. return c.conn.Close()
  367. }