ws2.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. package ws2
  2. import (
  3. "crypto/rand"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strings"
  12. "time"
  13. "git.me9.top/git/tinymq/config"
  14. "git.me9.top/git/tinymq/conn"
  15. "git.me9.top/git/tinymq/conn/util"
  16. "github.com/gorilla/websocket"
  17. )
  18. const PROTO string = "ws"
  19. const PROTO_STL string = "wss"
  20. const VERSION uint8 = 2
  21. type Ws2 struct {
  22. cf *config.Config
  23. conn *websocket.Conn
  24. cipher *util.Cipher // 记录当前的加解密类,可以保证在没有ssl的情况下数据安全
  25. }
  26. var upgrader = websocket.Upgrader{} // use default options
  27. // websocket 服务
  28. // 如果有绑定参数,则进行绑定操作代码
  29. func Server(cf *config.Config, bind string, path string, hash string, fn conn.ServerConnectFunc) (err error) {
  30. var ci *util.CipherInfo
  31. var encryptKey string
  32. if hash != "" {
  33. i := strings.Index(hash, ":")
  34. if i <= 0 {
  35. return errors.New("hash is invalid")
  36. }
  37. encryptMethod := hash[0:i]
  38. encryptKey = hash[i+1:]
  39. if c, ok := util.CipherMethod[encryptMethod]; ok {
  40. ci = c
  41. } else {
  42. return errors.New("Unsupported encryption method: " + encryptMethod)
  43. }
  44. }
  45. http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
  46. conn, err := upgrader.Upgrade(w, r, nil)
  47. if err != nil {
  48. log.Println("[ws2 Server Upgrade ERROR]", err)
  49. return
  50. }
  51. if ci == nil {
  52. ws := &Ws2{
  53. cf: cf,
  54. conn: conn,
  55. }
  56. fn(ws)
  57. return
  58. }
  59. var eiv []byte
  60. var div []byte
  61. if ci.IvLen > 0 {
  62. // 服务端 IV
  63. eiv = make([]byte, ci.IvLen)
  64. _, err = rand.Read(eiv)
  65. if err != nil {
  66. log.Println("[ws2 Server rand.Read ERROR]", err)
  67. return
  68. }
  69. // 发送 IV
  70. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  71. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  72. log.Println("[ws2 Server conn.Write ERROR]", err)
  73. return
  74. }
  75. // 读取 IV
  76. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  77. if err != nil {
  78. log.Println("[ws2 Server SetReadDeadline ERROR]", err)
  79. return
  80. }
  81. _, div, err = conn.ReadMessage()
  82. if err != nil {
  83. log.Println("[ws2 Server ReadFull ERROR]", err)
  84. return
  85. }
  86. }
  87. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  88. if err != nil {
  89. log.Println("[ws2 NewCipher ERROR]", err)
  90. return
  91. }
  92. ws := &Ws2{
  93. cf: cf,
  94. conn: conn,
  95. cipher: cipher,
  96. }
  97. fn(ws)
  98. })
  99. if bind != "" {
  100. go func() (err error) {
  101. defer func() {
  102. if err != nil {
  103. log.Fatal(err)
  104. }
  105. }()
  106. log.Printf("Listening and serving Websocket on %s\n", bind)
  107. // 暂时使用全局的方式,后面有需求再修改
  108. // 而且还没有 https 方式的绑定
  109. // 需要在前端增加其他的服务进行转换
  110. err = http.ListenAndServe(bind, nil)
  111. return
  112. }()
  113. }
  114. return
  115. }
  116. // 客户端,新建一个连接
  117. func Dial(cf *config.Config, scheme string, addr string, path string, hash string) (conn.Connect, error) {
  118. u := url.URL{Scheme: scheme, Host: addr, Path: path}
  119. // 没有加密的情况
  120. if hash == "" {
  121. conn, _, err := (&websocket.Dialer{
  122. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  123. }).Dial(u.String(), nil)
  124. if err != nil {
  125. return nil, err
  126. }
  127. ws := &Ws2{
  128. cf: cf,
  129. conn: conn,
  130. }
  131. return ws, nil
  132. }
  133. i := strings.Index(hash, ":")
  134. if i <= 0 {
  135. return nil, errors.New("hash is invalid")
  136. }
  137. encryptMethod := hash[0:i]
  138. encryptKey := hash[i+1:]
  139. ci, ok := util.CipherMethod[encryptMethod]
  140. if !ok {
  141. return nil, errors.New("Unsupported encryption method: " + encryptMethod)
  142. }
  143. conn, _, err := (&websocket.Dialer{
  144. HandshakeTimeout: time.Duration(time.Millisecond * time.Duration(cf.ConnectTimeout)),
  145. }).Dial(u.String(), nil)
  146. if err != nil {
  147. return nil, err
  148. }
  149. var eiv []byte
  150. var div []byte
  151. if ci.IvLen > 0 {
  152. // 客户端 IV
  153. eiv = make([]byte, ci.IvLen)
  154. _, err = rand.Read(eiv)
  155. if err != nil {
  156. log.Println("[ws2 Client rand.Read ERROR]", err)
  157. return nil, err
  158. }
  159. // 发送 IV
  160. conn.SetWriteDeadline(time.Now().Add(time.Duration(cf.WriteWait) * time.Millisecond))
  161. if err := conn.WriteMessage(websocket.BinaryMessage, eiv); err != nil {
  162. log.Println("[ws2 Client conn.Write ERROR]", err)
  163. return nil, err
  164. }
  165. // 读取 IV
  166. err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(cf.ReadWait)))
  167. if err != nil {
  168. log.Println("[ws2 Client SetReadDeadline ERROR]", err)
  169. return nil, err
  170. }
  171. _, div, err = conn.ReadMessage()
  172. if err != nil {
  173. log.Println("[ws2 Client ReadFull ERROR]", err)
  174. return nil, err
  175. }
  176. }
  177. cipher, err := util.NewCipher(ci, encryptKey, eiv, div)
  178. if err != nil {
  179. log.Println("[ws2 NewCipher ERROR]", err)
  180. return nil, err
  181. }
  182. ws := &Ws2{
  183. cf: cf,
  184. conn: conn,
  185. cipher: cipher,
  186. }
  187. return ws, nil
  188. }
  189. // 发送数据到网络
  190. // 如果有加密函数的话会直接修改源数据
  191. func (c *Ws2) writeMessage(buf []byte) (err error) {
  192. if c.cipher != nil {
  193. c.cipher.Encrypt(buf, buf)
  194. }
  195. c.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.WriteWait)))
  196. return c.conn.WriteMessage(websocket.BinaryMessage, buf)
  197. }
  198. // 发送Auth信息
  199. // 建立连接后第一个发送的消息
  200. func (c *Ws2) WriteAuthInfo(channel string, auth []byte) (err error) {
  201. protoLen := len(PROTO)
  202. if protoLen > 0xFF {
  203. return errors.New("length of protocol over")
  204. }
  205. channelLen := len(channel)
  206. if channelLen > 0xFFFF {
  207. return errors.New("length of channel over")
  208. }
  209. // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
  210. dlen := 2 + 1 + protoLen + 1 + 2 + channelLen + len(auth)
  211. start := 0
  212. buf := make([]byte, dlen)
  213. binary.BigEndian.PutUint16(buf[start:start+2], config.ID_AUTH)
  214. start += 2
  215. buf[start] = byte(protoLen)
  216. start++
  217. copy(buf[start:start+protoLen], []byte(PROTO))
  218. start += protoLen
  219. buf[start] = VERSION
  220. start++
  221. binary.BigEndian.PutUint16(buf[start:start+2], uint16(channelLen))
  222. start += 2
  223. copy(buf[start:start+channelLen], []byte(channel))
  224. start += channelLen
  225. copy(buf[start:], auth)
  226. return c.writeMessage(buf)
  227. }
  228. // 获取Auth信息
  229. // id(uint16)+proto(string)+version(uint8)+channel(string)+auth([]byte)
  230. func (c *Ws2) ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error) {
  231. defer func() {
  232. if r := recover(); r != nil {
  233. err = fmt.Errorf("recovered from panic: %v", r)
  234. return
  235. }
  236. }()
  237. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(c.cf.ReadWait)))
  238. if err != nil {
  239. return
  240. }
  241. _, msg, err := c.conn.ReadMessage()
  242. if err != nil {
  243. return
  244. }
  245. msgLen := len(msg)
  246. if msgLen < 4 {
  247. err = errors.New("message length less than 4")
  248. return
  249. }
  250. // 将读出来的数据进行解密
  251. if c.cipher != nil {
  252. c.cipher.Decrypt(msg, msg)
  253. }
  254. start := 0
  255. id := binary.BigEndian.Uint16(msg[start : start+2])
  256. if id != config.ID_AUTH {
  257. err = fmt.Errorf("wrong message id: %d", id)
  258. return
  259. }
  260. start += 2
  261. protoLen := int(msg[start])
  262. if protoLen < 2 {
  263. err = errors.New("wrong proto length")
  264. return
  265. }
  266. start++
  267. proto = string(msg[start : start+protoLen])
  268. if proto != PROTO {
  269. err = fmt.Errorf("wrong proto: %s", proto)
  270. return
  271. }
  272. start += protoLen
  273. version = msg[start]
  274. if version != VERSION {
  275. err = fmt.Errorf("require version %d, get version: %d", VERSION, version)
  276. return
  277. }
  278. start++
  279. channelLen := int(binary.BigEndian.Uint16(msg[start : start+2]))
  280. if channelLen < 2 {
  281. err = errors.New("wrong channel length")
  282. return
  283. }
  284. start += 2
  285. channel = string(msg[start : start+channelLen])
  286. start += channelLen
  287. auth = msg[start:]
  288. return
  289. }
  290. // 发送请求数据包到网络
  291. func (c *Ws2) WriteRequest(id uint16, cmd string, data []byte) error {
  292. // 为了区分请求还是响应包,命令字符串不能超过127个字节,如果超过则报错
  293. cmdLen := len(cmd)
  294. if cmdLen > 0x7F {
  295. return errors.New("length of command more than 0x7F")
  296. }
  297. dlen := 2 + 1 + cmdLen + len(data)
  298. buf := make([]byte, dlen) // 申请内存
  299. binary.BigEndian.PutUint16(buf[0:2], id)
  300. buf[2] = byte(cmdLen)
  301. copy(buf[3:], cmd)
  302. copy(buf[3+cmdLen:], data)
  303. return c.writeMessage(buf)
  304. }
  305. // 发送响应数据包到网络
  306. // 网络格式:[id, stateCode, data]
  307. func (c *Ws2) WriteResponse(id uint16, state uint8, data []byte) error {
  308. dlen := 2 + 1 + len(data)
  309. buf := make([]byte, dlen)
  310. binary.BigEndian.PutUint16(buf[0:2], id)
  311. buf[2] = state | 0x80
  312. copy(buf[3:], data)
  313. return c.writeMessage(buf)
  314. }
  315. // 发送ping包
  316. func (c *Ws2) WritePing(id uint16) error {
  317. buf := make([]byte, 2)
  318. binary.BigEndian.PutUint16(buf[0:2], id)
  319. return c.writeMessage(buf)
  320. }
  321. // 获取信息
  322. func (c *Ws2) ReadMessage(deadline int) (msgType conn.MsgType, id uint16, cmd string, state uint8, data []byte, err error) {
  323. err = c.conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(deadline)))
  324. if err != nil {
  325. return
  326. }
  327. _, msg, err := c.conn.ReadMessage()
  328. if err != nil {
  329. return
  330. }
  331. msgLen := len(msg)
  332. if msgLen < 2 {
  333. err = errors.New("message length less than 2")
  334. return
  335. }
  336. // 将读出来的数据进行解密
  337. if c.cipher != nil {
  338. c.cipher.Decrypt(msg, msg)
  339. }
  340. id = binary.BigEndian.Uint16(msg[0:2])
  341. // ping信息
  342. if msgLen == 2 {
  343. msgType = conn.PingMsg
  344. return
  345. }
  346. if id > config.ID_MAX {
  347. err = fmt.Errorf("wrong message id: %d", id)
  348. return
  349. }
  350. cmdx := msg[2]
  351. if (cmdx & 0x80) == 0 {
  352. // 请求包
  353. msgType = conn.RequestMsg
  354. cmdLen := int(cmdx)
  355. cmd = string(msg[3 : cmdLen+3])
  356. data = msg[cmdLen+3:]
  357. return
  358. } else {
  359. // 响应数据包
  360. msgType = conn.ResponseMsg
  361. state = cmdx & 0x7F
  362. data = msg[3:]
  363. return
  364. }
  365. }
  366. // 获取远程的地址
  367. func (c *Ws2) RemoteAddr() net.Addr {
  368. return c.conn.RemoteAddr()
  369. }
  370. // 获取本地的地址
  371. func (c *Ws2) LocalAddr() net.Addr {
  372. return c.conn.LocalAddr()
  373. }
  374. func (c *Ws2) Close() error {
  375. return c.conn.Close()
  376. }