ws2.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. package ws2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strings"
  12. "time"
  13. "git.me9.top/git/tinymq/config"
  14. "git.me9.top/git/tinymq/conn"
  15. "git.me9.top/git/tinymq/conn/util"
  16. "github.com/gorilla/websocket"
  17. )
  // Protocol identification constants.
  const (
  	// PROTO is the protocol name written into (and required in) the auth frame.
  	PROTO string = "ws"
  	// PROTO_STL holds "wss", the websocket-over-TLS scheme name; not referenced in this file.
  	PROTO_STL string = "wss"
  	// VERSION is the wire-format version written into and required in the auth frame.
  	VERSION uint8 = 2
  )
  21. type Ws2 struct {
  22. cf *config.Config
  23. conn *websocket.Conn
  24. cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
  25. compress bool // 是否启用压缩
  26. }
  27. var upgrader = websocket.Upgrader{
  28. CheckOrigin: func(r *http.Request) bool {
  29. return true // 允许任何Origin跨站访问
  30. },
  31. } // use default options
  32. // websocket 服务
  33. // 如果有绑定参数,则进行绑定操作代码
  34. func Server(cf *config.Config, bind string, path string, hash string, fn conn.ServerConnectFunc) (err error) {
  35. var ci *util.CipherInfo
  36. var encryptKey string
  37. if hash != "" {
  38. i := strings.Index(hash, ":")
  39. if i <= 0 {
  40. return errors.New("hash is invalid")
  41. }
  42. encryptMethod := hash[0:i]
  43. encryptKey = hash[i+1:]
  44. if c, ok := util.CipherMethod[encryptMethod]; ok {
  45. ci = c
  46. } else {
  47. return errors.New("Unsupported encryption method: " + encryptMethod)
  48. }
  49. }
  50. http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
  51. conn, err := upgrader.Upgrade(w, r, nil)
  52. if err != nil {
  53. log.Println("[ws2 Server Upgrade ERROR]", err)
  54. return
  55. }
  56. if ci == nil {
  57. ws := &Ws2{
  58. cf: cf,
  59. conn: conn,
  60. }
  61. fn(ws)
  62. return
  63. }
  64. var eiv []byte
  65. var div []byte
  66. if ci.IvLen > 0 {
  67. // 服务端 IV
  68. eiv = make([]byte, ci.IvLen)
  69. _, err = rand.Read(eiv)
  70. if err != nil {
  71. log.Println("[ws2 Server rand.Read ERROR]", err)
  72. return
  73. }
  74. // 发送 IV
  75. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  76. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  77. log.Println("[ws2 Server conn.Write ERROR]", err)
  78. return
  79. }
  80. // 读取 IV
  81. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  82. if err != nil {
  83. log.Println("[ws2 Server SetReadDeadline ERROR]", err)
  84. return
  85. }
  86. _, div, err = conn.ReadMessage()
  87. if err != nil {
  88. log.Println("[ws2 Server ReadFull ERROR]", err)
  89. conn.Close()
  90. return
  91. }
  92. }
  93. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  94. if err != nil {
  95. log.Println("[ws2 NewCipher ERROR]", err)
  96. return
  97. }
  98. ws := &Ws2{
  99. cf: cf,
  100. conn: conn,
  101. cipher: cipher,
  102. }
  103. fn(ws)
  104. })
  105. if bind != "" {
  106. go func() (err error) {
  107. defer func() {
  108. if err != nil {
  109. log.Fatal(err)
  110. }
  111. }()
  112. log.Printf("Listening and serving Websocket on %s\n", bind)
  113. // 暂时使用全局的方式,后面有需求再修改
  114. // 而且还没有 https 方式的绑定
  115. // 需要在前端增加其他的服务进行转换
  116. err = http.ListenAndServe(bind, nil)
  117. return
  118. }()
  119. }
  120. return
  121. }
  122. // 客户端,新建一个连接
  123. func Dial(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) {
  124. u := url.URL{Scheme: scheme, Host: addr, Path: path}
  125. // 没有加密的情况
  126. if hash == "" {
  127. conn, _, err := (&websocket.Dialer{
  128. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  129. }).Dial(u.String(), nil)
  130. if err != nil {
  131. return nil, err
  132. }
  133. ws := &Ws2{
  134. cf: cf,
  135. conn: conn,
  136. }
  137. return ws, nil
  138. }
  139. i := strings.Index(hash, ":")
  140. if i <= 0 {
  141. return nil, errors.New("hash is invalid")
  142. }
  143. encryptMethod := hash[0:i]
  144. encryptKey := hash[i+1:]
  145. ci, ok := util.CipherMethod[encryptMethod]
  146. if !ok {
  147. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  148. }
  149. conn, _, err := (&websocket.Dialer{
  150. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  151. }).Dial(u.String(), nil)
  152. if err != nil {
  153. return nil, err
  154. }
  155. var eiv []byte
  156. var div []byte
  157. if ci.IvLen > 0 {
  158. // 客户端 IV
  159. eiv = make([]byte, ci.IvLen)
  160. _, err = rand.Read(eiv)
  161. if err != nil {
  162. log.Println("[ws2 Client rand.Read ERROR]", err)
  163. return nil, err
  164. }
  165. // 发送 IV
  166. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  167. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  168. log.Println("[ws2 Client conn.Write ERROR]", err)
  169. return nil, err
  170. }
  171. // 读取 IV
  172. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  173. if err != nil {
  174. log.Println("[ws2 Client SetReadDeadline ERROR]", err)
  175. return nil, err
  176. }
  177. _, div, err = conn.ReadMessage()
  178. if err != nil {
  179. log.Println("[ws2 Client ReadFull ERROR]", err)
  180. return nil, err
  181. }
  182. }
  183. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  184. if err != nil {
  185. log.Println("[ws2 NewCipher ERROR]", err)
  186. return nil, err
  187. }
  188. ws := &Ws2{
  189. cf: cf,
  190. conn: conn,
  191. cipher: cipher,
  192. compress: true,
  193. }
  194. return ws, nil
  195. }
  196. // 发送数据到网络
  197. // 如果有加密函数的话会直接修改源数据
  198. func (c *Ws2) writeMessage(buf []byte) (err error) {
  199. if c.cipher != nil {
  200. c.cipher.Encrypt(buf, buf)
  201. }
  202. c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.WriteWait)))
  203. return c.conn.WriteMessage(websocket.BinaryMessage, buf)
  204. }
  205. // 发送Auth信息
  206. // 建立连接后第一个发送的消息
  207. func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
  208. protoLen := len(PROTO)
  209. if protoLen > 0xFF {
  210. return errors.New("length of protocol over")
  211. }
  212. channelLen := len(channel)
  213. if channelLen > 0xFFFF {
  214. return errors.New("length of channel over")
  215. }
  216. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  217. dlen := 2 + 1 + protoLen + 1 + 1 + 2 + channelLen + len(auth)
  218. index := 0
  219. buf := make([]byte, dlen)
  220. binary.BigEndian.PutUint16(buf[index:index+2], config.ID_AUTH)
  221. index += 2
  222. buf[index] = byte(protoLen)
  223. index++
  224. copy(buf[index:index+protoLen], []byte(PROTO))
  225. index += protoLen
  226. buf[index] = VERSION
  227. index++
  228. if c.compress {
  229. buf[index] = 0x01
  230. } else {
  231. buf[index] = 0
  232. }
  233. index++
  234. binary.BigEndian.PutUint16(buf[index:index+2], uint16(channelLen))
  235. index += 2
  236. copy(buf[index:index+channelLen], []byte(channel))
  237. index += channelLen
  238. copy(buf[index:], auth)
  239. return c.writeMessage(buf)
  240. }
  241. // 获取Auth信息
  242. // id(65502)+proto(string)+version(uint8)+option(byte)+channel(string)+auth([]byte)
  243. func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  244. defer func() {
  245. if r := recover(); r != nil {
  246. err = fmt.Errorf("recovered from panic: %v", r)
  247. return
  248. }
  249. }()
  250. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.ReadWait)))
  251. if err != nil {
  252. return
  253. }
  254. _, msg, err := c.conn.ReadMessage()
  255. if err != nil {
  256. return
  257. }
  258. msgLen := len(msg)
  259. if msgLen < 9 {
  260. err = errors.New("wrong message length")
  261. return
  262. }
  263. // 将读出来的数据进行解密
  264. if c.cipher != nil {
  265. c.cipher.Decrypt(msg, msg)
  266. }
  267. index := 0
  268. id := binary.BigEndian.Uint16(msg[index : index+2])
  269. if id != config.ID_AUTH {
  270. err = fmt.Errorf("wrong message id: %d", id)
  271. return
  272. }
  273. index += 2
  274. protoLen := int(msg[index])
  275. if protoLen < 2 {
  276. err = errors.New("wrong proto length")
  277. return
  278. }
  279. index++
  280. proto = string(msg[index : index+protoLen])
  281. if proto != PROTO {
  282. err = fmt.Errorf("wrong proto: %s", proto)
  283. return
  284. }
  285. index += protoLen
  286. version = msg[index]
  287. if version != VERSION {
  288. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  289. return
  290. }
  291. index++
  292. c.compress = (msg[index] & 0x01) != 0
  293. index++
  294. channelLen := int(binary.BigEndian.Uint16(msg[index : index+2]))
  295. if channelLen < 2 {
  296. err = errors.New("wrong channel length")
  297. return
  298. }
  299. index += 2
  300. channel = string(msg[index : index+channelLen])
  301. index += channelLen
  302. auth = msg[index:]
  303. return
  304. }
  305. // 发送请求数据包到网络
  306. func (c *Ws2) WriteRequest(id uint16, cmd string, data []byte) error {
  307. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则报错
  308. cmdLen := len(cmd)
  309. if cmdLen > 0x7F {
  310. return errors.New("length of command more than 0x7F")
  311. }
  312. dlen := len(data)
  313. if c.compress && dlen > 0 {
  314. compressedData, ok := util.CompressData(data)
  315. ddlen := 2 + 1 + cmdLen + 1 + len(compressedData)
  316. buf := make([]byte, ddlen) // 申请内存
  317. index := 0
  318. binary.BigEndian.PutUint16(buf[index:index+2], id)
  319. index += 2
  320. buf[index] = byte(cmdLen)
  321. index++
  322. copy(buf[index:], cmd)
  323. index += cmdLen
  324. if ok {
  325. buf[index] = 0x01
  326. if c.cf.PrintMsg {
  327. log.Println("[CompressData]", len(data), "->", len(compressedData))
  328. }
  329. } else {
  330. buf[index] = 0
  331. }
  332. index++
  333. copy(buf[index:], compressedData)
  334. return c.writeMessage(buf)
  335. } else {
  336. ddlen := 2 + 1 + cmdLen + dlen
  337. buf := make([]byte, ddlen) // 申请内存
  338. binary.BigEndian.PutUint16(buf[0:2], id)
  339. buf[2] = byte(cmdLen)
  340. copy(buf[3:], cmd)
  341. copy(buf[3+cmdLen:], data)
  342. return c.writeMessage(buf)
  343. }
  344. }
  345. // 发送响应数据包到网络
  346. // 网络格式:[id, stateCode, data]
  347. func (c *Ws2) WriteResponse(id uint16, state uint8, data []byte) error {
  348. dlen := len(data)
  349. if c.compress && dlen > 0 {
  350. compressedData, ok := util.CompressData(data)
  351. ddlen := 2 + 1 + 1 + len(compressedData)
  352. buf := make([]byte, ddlen)
  353. binary.BigEndian.PutUint16(buf[0:2], id)
  354. buf[2] = state | 0x80
  355. if ok {
  356. buf[3] = 0x01
  357. if c.cf.PrintMsg {
  358. log.Println("[CompressData]", len(data), "->", len(compressedData))
  359. }
  360. } else {
  361. buf[3] = 0
  362. }
  363. copy(buf[4:], compressedData)
  364. return c.writeMessage(buf)
  365. } else {
  366. ddlen := 2 + 1 + dlen
  367. buf := make([]byte, ddlen)
  368. binary.BigEndian.PutUint16(buf[0:2], id)
  369. buf[2] = state | 0x80
  370. copy(buf[3:], data)
  371. return c.writeMessage(buf)
  372. }
  373. }
  374. // 发送ping包
  375. func (c *Ws2) WritePing(id uint16) error {
  376. buf := make([]byte, 2)
  377. binary.BigEndian.PutUint16(buf[0:2], id)
  378. return c.writeMessage(buf)
  379. }
  380. // 获取信息
  381. func (c *Ws2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  382. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  383. if err != nil {
  384. return
  385. }
  386. _, msg, err := c.conn.ReadMessage()
  387. if err != nil {
  388. return
  389. }
  390. msgLen := len(msg)
  391. if msgLen < 2 {
  392. err = errors.New("message length less than 2")
  393. return
  394. }
  395. // 将读出来的数据进行解密
  396. if c.cipher != nil {
  397. c.cipher.Decrypt(msg, msg)
  398. }
  399. id = binary.BigEndian.Uint16(msg[0:2])
  400. // ping信息
  401. if msgLen == 2 {
  402. msgType = conn.PingMsg
  403. return
  404. }
  405. if id > config.ID_MAX {
  406. err = fmt.Errorf("wrong message id: %d", id)
  407. return
  408. }
  409. cmdx := msg[2]
  410. if (cmdx & 0x80) == 0 {
  411. // 请求包
  412. msgType = conn.RequestMsg
  413. cmdLen := int(cmdx)
  414. cmd = string(msg[3 : cmdLen+3])
  415. data = msg[cmdLen+3:]
  416. } else {
  417. // 响应数据包
  418. msgType = conn.ResponseMsg
  419. state = cmdx & 0x7F
  420. data = msg[3:]
  421. }
  422. if c.compress && len(data) > 1 {
  423. data, _ = util.UncompressData(data)
  424. }
  425. return
  426. }
  427. // 获取远程的地址
  428. func (c *Ws2) RemoteAddr() net.Addr {
  429. return c.conn.RemoteAddr()
  430. }
  431. // 获取本地的地址
  432. func (c *Ws2) LocalAddr() net.Addr {
  433. return c.conn.LocalAddr()
  434. }
  435. func (c *Ws2) Close() error {
  436. return c.conn.Close()
  437. }