ws2.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. package ws2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strings"
  12. "time"
  13. "git.me9.top/git/tinymq/config"
  14. "git.me9.top/git/tinymq/conn"
  15. "git.me9.top/git/tinymq/conn/util"
  16. "github.com/gorilla/websocket"
  17. )
  // Protocol identifiers and wire-format version for the ws2 transport.
const (
	// PROTO is the protocol name written into, and required by, the auth frame.
	PROTO string = "ws"
	// PROTO_STL is "wss"; presumably the TLS variant of PROTO (it is not
	// referenced elsewhere in this file).
	PROTO_STL string = "wss"
	// VERSION is the ws2 wire-format version carried in the auth frame.
	VERSION uint8 = 2
)
  21. type Ws2 struct {
  22. cf *config.Config
  23. conn *websocket.Conn
  24. cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
  25. compress bool // 是否启用压缩
  26. }
  27. var upgrader = websocket.Upgrader{} // use default options
  28. // websocket 服务
  29. // 如果有绑定参数,则进行绑定操作代码
  30. func Server(cf *config.Config, bind string, path string, hash string, fn conn.ServerConnectFunc) (err error) {
  31. var ci *util.CipherInfo
  32. var encryptKey string
  33. if hash != "" {
  34. i := strings.Index(hash, ":")
  35. if i <= 0 {
  36. return errors.New("hash is invalid")
  37. }
  38. encryptMethod := hash[0:i]
  39. encryptKey = hash[i+1:]
  40. if c, ok := util.CipherMethod[encryptMethod]; ok {
  41. ci = c
  42. } else {
  43. return errors.New("Unsupported encryption method: " + encryptMethod)
  44. }
  45. }
  46. http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
  47. conn, err := upgrader.Upgrade(w, r, nil)
  48. if err != nil {
  49. log.Println("[ws2 Server Upgrade ERROR]", err)
  50. return
  51. }
  52. if ci == nil {
  53. ws := &Ws2{
  54. cf: cf,
  55. conn: conn,
  56. }
  57. fn(ws)
  58. return
  59. }
  60. var eiv []byte
  61. var div []byte
  62. if ci.IvLen > 0 {
  63. // 服务端 IV
  64. eiv = make([]byte, ci.IvLen)
  65. _, err = rand.Read(eiv)
  66. if err != nil {
  67. log.Println("[ws2 Server rand.Read ERROR]", err)
  68. return
  69. }
  70. // 发送 IV
  71. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  72. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  73. log.Println("[ws2 Server conn.Write ERROR]", err)
  74. return
  75. }
  76. // 读取 IV
  77. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  78. if err != nil {
  79. log.Println("[ws2 Server SetReadDeadline ERROR]", err)
  80. return
  81. }
  82. _, div, err = conn.ReadMessage()
  83. if err != nil {
  84. log.Println("[ws2 Server ReadFull ERROR]", err)
  85. return
  86. }
  87. }
  88. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  89. if err != nil {
  90. log.Println("[ws2 NewCipher ERROR]", err)
  91. return
  92. }
  93. ws := &Ws2{
  94. cf: cf,
  95. conn: conn,
  96. cipher: cipher,
  97. }
  98. fn(ws)
  99. })
  100. if bind != "" {
  101. go func() (err error) {
  102. defer func() {
  103. if err != nil {
  104. log.Fatal(err)
  105. }
  106. }()
  107. log.Printf("Listening and serving Websocket on %s\n", bind)
  108. // 暂时使用全局的方式,后面有需求再修改
  109. // 而且还没有 https 方式的绑定
  110. // 需要在前端增加其他的服务进行转换
  111. err = http.ListenAndServe(bind, nil)
  112. return
  113. }()
  114. }
  115. return
  116. }
  117. // 客户端,新建一个连接
  118. func Dial(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) {
  119. u := url.URL{Scheme: scheme, Host: addr, Path: path}
  120. // 没有加密的情况
  121. if hash == "" {
  122. conn, _, err := (&websocket.Dialer{
  123. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  124. }).Dial(u.String(), nil)
  125. if err != nil {
  126. return nil, err
  127. }
  128. ws := &Ws2{
  129. cf: cf,
  130. conn: conn,
  131. }
  132. return ws, nil
  133. }
  134. i := strings.Index(hash, ":")
  135. if i <= 0 {
  136. return nil, errors.New("hash is invalid")
  137. }
  138. encryptMethod := hash[0:i]
  139. encryptKey := hash[i+1:]
  140. ci, ok := util.CipherMethod[encryptMethod]
  141. if !ok {
  142. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  143. }
  144. conn, _, err := (&websocket.Dialer{
  145. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  146. }).Dial(u.String(), nil)
  147. if err != nil {
  148. return nil, err
  149. }
  150. var eiv []byte
  151. var div []byte
  152. if ci.IvLen > 0 {
  153. // 客户端 IV
  154. eiv = make([]byte, ci.IvLen)
  155. _, err = rand.Read(eiv)
  156. if err != nil {
  157. log.Println("[ws2 Client rand.Read ERROR]", err)
  158. return nil, err
  159. }
  160. // 发送 IV
  161. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  162. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  163. log.Println("[ws2 Client conn.Write ERROR]", err)
  164. return nil, err
  165. }
  166. // 读取 IV
  167. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  168. if err != nil {
  169. log.Println("[ws2 Client SetReadDeadline ERROR]", err)
  170. return nil, err
  171. }
  172. _, div, err = conn.ReadMessage()
  173. if err != nil {
  174. log.Println("[ws2 Client ReadFull ERROR]", err)
  175. return nil, err
  176. }
  177. }
  178. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  179. if err != nil {
  180. log.Println("[ws2 NewCipher ERROR]", err)
  181. return nil, err
  182. }
  183. ws := &Ws2{
  184. cf: cf,
  185. conn: conn,
  186. cipher: cipher,
  187. compress: true,
  188. }
  189. return ws, nil
  190. }
  191. // 发送数据到网络
  192. // 如果有加密函数的话会直接修改源数据
  193. func (c *Ws2) writeMessage(buf []byte) (err error) {
  194. if c.cipher != nil {
  195. c.cipher.Encrypt(buf, buf)
  196. }
  197. c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.WriteWait)))
  198. return c.conn.WriteMessage(websocket.BinaryMessage, buf)
  199. }
  200. // 发送Auth信息
  201. // 建立连接后第一个发送的消息
  202. func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
  203. protoLen := len(PROTO)
  204. if protoLen > 0xFF {
  205. return errors.New("length of protocol over")
  206. }
  207. channelLen := len(channel)
  208. if channelLen > 0xFFFF {
  209. return errors.New("length of channel over")
  210. }
  211. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  212. dlen := 2 + 1 + protoLen + 1 + 1 + 2 + channelLen + len(auth)
  213. index := 0
  214. buf := make([]byte, dlen)
  215. binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
  216. index += 2
  217. buf[index] = byte(protoLen)
  218. index++
  219. copy(buf[index:index+protoLen], []byte(PROTO))
  220. index += protoLen
  221. buf[index] = VERSION
  222. index++
  223. if c.compress {
  224. buf[index] = 0x01
  225. } else {
  226. buf[index] = 0
  227. }
  228. index++
  229. binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
  230. index += 2
  231. copy(buf[index:index+channelLen], []byte(channel))
  232. index += channelLen
  233. copy(buf[index:], auth)
  234. return c.writeMessage(buf)
  235. }
  236. // 获取Auth信息
  237. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  238. func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  239. defer func() {
  240. if r := recover(); r != nil {
  241. err = fmt.Errorf("recovered from panic: %v", r)
  242. return
  243. }
  244. }()
  245. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.ReadWait)))
  246. if err != nil {
  247. return
  248. }
  249. _, msg, err := c.conn.ReadMessage()
  250. if err != nil {
  251. return
  252. }
  253. msgLen := len(msg)
  254. if msgLen < 9 {
  255. err = errors.New("wrong message length")
  256. return
  257. }
  258. // 将读出来的数据进行解密
  259. if c.cipher != nil {
  260. c.cipher.Decrypt(msg, msg)
  261. }
  262. index := 0
  263. id := binary.BigEndian.Uint16(msg[index : index+2])
  264. if id != config.ID_AUTH {
  265. err = fmt.Errorf("wrong message id: %d", id)
  266. return
  267. }
  268. index += 2
  269. protoLen := int(msg[index])
  270. if protoLen < 2 {
  271. err = errors.New("wrong proto length")
  272. return
  273. }
  274. index++
  275. proto = string(msg[index : index+protoLen])
  276. if proto != PROTO {
  277. err = fmt.Errorf("wrong proto: %s", proto)
  278. return
  279. }
  280. index += protoLen
  281. version = msg[index]
  282. if version != VERSION {
  283. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  284. return
  285. }
  286. index++
  287. c.compress = (msg[index] & 0x01) != 0
  288. index++
  289. channelLen := int(binary.BigEndian.Uint16(msg[index : index+2]))
  290. if channelLen < 2 {
  291. err = errors.New("wrong channel length")
  292. return
  293. }
  294. index += 2
  295. channel = string(msg[index : index+channelLen])
  296. index += channelLen
  297. auth = msg[index:]
  298. return
  299. }
  300. // 发送请求数据包到网络
  301. func (c *Ws2) WriteRequest(id uint16, cmd string, data []byte) error {
  302. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则报错
  303. cmdLen := len(cmd)
  304. if cmdLen > 0x7F {
  305. return errors.New("length of command more than 0x7F")
  306. }
  307. dlen := len(data)
  308. if c.compress && dlen > 0 {
  309. compressedData, ok := util.CompressData(data)
  310. ddlen := 2 + 1 + cmdLen + 1 + len(compressedData)
  311. buf := make([]byte, ddlen) // 申请内存
  312. index := 0
  313. binary.BigEndian.PutUint16(buf[index:index+2], id)
  314. index += 2
  315. buf[index] = byte(cmdLen)
  316. index++
  317. copy(buf[index:], cmd)
  318. index += cmdLen
  319. if ok {
  320. buf[index] = 0x01
  321. } else {
  322. buf[index] = 0
  323. }
  324. index++
  325. copy(buf[index:], compressedData)
  326. return c.writeMessage(buf)
  327. } else {
  328. ddlen := 2 + 1 + cmdLen + dlen
  329. buf := make([]byte, ddlen) // 申请内存
  330. binary.BigEndian.PutUint16(buf[0:2], id)
  331. buf[2] = byte(cmdLen)
  332. copy(buf[3:], cmd)
  333. copy(buf[3+cmdLen:], data)
  334. return c.writeMessage(buf)
  335. }
  336. }
  337. // 发送响应数据包到网络
  338. // 网络格式:[id, stateCode, data]
  339. func (c *Ws2) WriteResponse(id uint16, state uint8, data []byte) error {
  340. dlen := len(data)
  341. if c.compress && dlen > 0 {
  342. compressedData, ok := util.CompressData(data)
  343. ddlen := 2 + 1 + 1 + len(compressedData)
  344. buf := make([]byte, ddlen)
  345. binary.BigEndian.PutUint16(buf[0:2], id)
  346. buf[2] = state | 0x80
  347. if ok {
  348. buf[3] = 0x01
  349. } else {
  350. buf[3] = 0
  351. }
  352. copy(buf[4:], compressedData)
  353. return c.writeMessage(buf)
  354. } else {
  355. ddlen := 2 + 1 + dlen
  356. buf := make([]byte, ddlen)
  357. binary.BigEndian.PutUint16(buf[0:2], id)
  358. buf[2] = state | 0x80
  359. copy(buf[3:], data)
  360. return c.writeMessage(buf)
  361. }
  362. }
  363. // 发送ping包
  364. func (c *Ws2) WritePing(id uint16) error {
  365. buf := make([]byte, 2)
  366. binary.BigEndian.PutUint16(buf[0:2], id)
  367. return c.writeMessage(buf)
  368. }
  369. // 获取信息
  370. func (c *Ws2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  371. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  372. if err != nil {
  373. return
  374. }
  375. _, msg, err := c.conn.ReadMessage()
  376. if err != nil {
  377. return
  378. }
  379. msgLen := len(msg)
  380. if msgLen < 2 {
  381. err = errors.New("message length less than 2")
  382. return
  383. }
  384. // 将读出来的数据进行解密
  385. if c.cipher != nil {
  386. c.cipher.Decrypt(msg, msg)
  387. }
  388. id = binary.BigEndian.Uint16(msg[0:2])
  389. // ping信息
  390. if msgLen == 2 {
  391. msgType = conn.PingMsg
  392. return
  393. }
  394. if id > config.ID_MAX {
  395. err = fmt.Errorf("wrong message id: %d", id)
  396. return
  397. }
  398. cmdx := msg[2]
  399. if (cmdx & 0x80) == 0 {
  400. // 请求包
  401. msgType = conn.RequestMsg
  402. cmdLen := int(cmdx)
  403. cmd = string(msg[3 : cmdLen+3])
  404. data = msg[cmdLen+3:]
  405. } else {
  406. // 响应数据包
  407. msgType = conn.ResponseMsg
  408. state = cmdx & 0x7F
  409. data = msg[3:]
  410. }
  411. if c.compress && len(data) > 1 {
  412. data, _ = util.UncompressData(data)
  413. }
  414. return
  415. }
  416. // 获取远程的地址
  417. func (c *Ws2) RemoteAddr() net.Addr {
  418. return c.conn.RemoteAddr()
  419. }
  420. // 获取本地的地址
  421. func (c *Ws2) LocalAddr() net.Addr {
  422. return c.conn.LocalAddr()
  423. }
  424. func (c *Ws2) Close() error {
  425. return c.conn.Close()
  426. }