Pārlūkot izejas kodu

fix auth function

Joyit 1 mēnesi atpakaļ
vecāks
revīzija
50895f5d4b
5 mainītis faili ar 47 papildinājumiem un 38 dzēšanām
  1. 13 10
      examples/client-tcp2.go
  2. 13 10
      examples/client-ws2.go
  3. 14 11
      examples/server.go
  4. 5 5
      hub.go
  5. 2 2
      type.go

+ 13 - 10
examples/client-tcp2.go

@@ -26,17 +26,20 @@ func main() {
 
 	hub := tinymq.NewHub(cf, localChannel, func(channel string, hostType tinymq.HostType) (hostInfo *tinymq.HostInfo, err error) {
 		return host, nil
-	}, func(proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
+	}, func(client bool, proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
+		log.Println("[AuthFunc]", client, proto, version, channel, string(remoteAuth))
+		return []byte("tinymq-client")
 		// 从 remoteAuth 是否为空来判断是否需要返回信息
-		if len(remoteAuth) <= 0 {
-			// 客户端调用,返回验证信息
-			return []byte("tinymq")
-		} else {
-			// 服务端调用,返回验证token,或者其他信息
-			return nil
-		}
-	}, func(proto string, version uint8, channel string, auth []byte) bool {
-		return true
+		// if len(remoteAuth) <= 0 {
+		// 	// 客户端调用,返回验证信息
+		// 	return []byte("tinymq")
+		// } else {
+		// 	// 服务端调用,返回验证token,或者其他信息
+		// 	return nil
+		// }
+	}, func(client bool, proto string, version uint8, channel string, auth []byte) bool {
+		log.Println("[CheckAuthFunc]", client, proto, version, channel, string(auth))
+		return string(auth) == "tinymq-server"
 	}, func(conn *tinymq.Line) {
 		log.Println("connect state", conn.Channel(), conn.State(), time.Since(conn.Updated()))
 	}, nil)

+ 13 - 10
examples/client-ws2.go

@@ -29,17 +29,20 @@ func main() {
 
 	hub := tinymq.NewHub(cf, localChannel, func(channel string, hostType tinymq.HostType) (hostInfo *tinymq.HostInfo, err error) {
 		return host, nil
-	}, func(proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
+	}, func(client bool, proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
+		log.Println("[AuthFunc]", client, proto, version, channel, string(remoteAuth))
+		return []byte("tinymq-client")
 		// 从 remoteAuth 是否为空来判断是否需要返回信息
-		if len(remoteAuth) <= 0 {
-			// 客户端调用,返回验证信息
-			return []byte("tinymq")
-		} else {
-			// 服务端调用,返回验证token,或者其他信息
-			return nil
-		}
-	}, func(proto string, cversion uint8, hannel string, auth []byte) bool {
-		return true
+		// if len(remoteAuth) <= 0 {
+		// 	// 客户端调用,返回验证信息
+		// 	return []byte("tinymq")
+		// } else {
+		// 	// 服务端调用,返回验证token,或者其他信息
+		// 	return nil
+		// }
+	}, func(client bool, proto string, version uint8, channel string, auth []byte) bool {
+		log.Println("[CheckAuthFunc]", client, proto, version, channel, string(auth))
+		return string(auth) == "tinymq-server"
 	}, func(conn *tinymq.Line) {
 		log.Println("connect state", conn.Channel(), conn.State(), time.Since(conn.Updated()))
 	}, nil)

+ 14 - 11
examples/server.go

@@ -23,17 +23,20 @@ func main() {
 	var hub *tinymq.Hub
 
 	hub = tinymq.NewHub(cf, localChannel,
-		nil, func(proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
-			// 从 remoteAuth 是否为空来判断是否需要返回信息
-			if len(remoteAuth) <= 0 {
-				// 客户端调用,返回验证信息
-				return []byte("tinymq")
-			} else {
-				// 服务端调用,返回验证token,或者其他信息
-				return nil
-			}
-		}, func(proto string, version uint8, channel string, auth []byte) bool {
-			return string(auth) == "tinymq"
+		nil, func(client bool, proto string, version uint8, channel string, remoteAuth []byte) (auth []byte) {
+			log.Println("[AuthFunc]", client, proto, version, channel, string(remoteAuth))
+			return []byte("tinymq-server")
+			// // 从 remoteAuth 是否为空来判断是否需要返回信息
+			// if len(remoteAuth) <= 0 {
+			// 	// 客户端调用,返回验证信息
+			// 	return []byte("tinymq")
+			// } else {
+			// 	// 服务端调用,返回验证token,或者其他信息
+			// 	return nil
+			// }
+		}, func(client bool, proto string, version uint8, channel string, auth []byte) bool {
+			log.Println("[CheckAuthFunc]", client, proto, version, channel, string(auth))
+			return string(auth) == "tinymq-client"
 		}, func(conn *tinymq.Line) {
 			log.Println("[Connect state change]", conn.Channel(), conn.State(), time.Since(conn.Updated()))
 			if conn.State() == tinymq.Connected {

+ 5 - 5
hub.go

@@ -588,12 +588,12 @@ func (h *Hub) BindForServer(info *HostInfo) (err error) {
 			conn.Close()
 			return
 		}
-		if !h.checkAuthFunc(proto, version, channel, auth) {
+		if !h.checkAuthFunc(false, proto, version, channel, auth) {
 			conn.Close()
 			return
 		}
 		// 发送频道信息
-		if err := conn.WriteAuthInfo(h.channel, h.authFunc(proto, version, channel, auth)); err != nil {
+		if err := conn.WriteAuthInfo(h.channel, h.authFunc(false, proto, version, channel, auth)); err != nil {
 			log.Println("[WriteAuthInfo ERROR]", err)
 			conn.Close()
 			return
@@ -704,7 +704,7 @@ func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err e
 		localChannel = localChannel + "?proxy=" + host.Host
 	}
 	// 发送验证信息
-	if err := conn.WriteAuthInfo(localChannel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
+	if err := conn.WriteAuthInfo(localChannel, h.authFunc(true, runProto, host.Version, channel, nil)); err != nil {
 		log.Println("[WriteAuthInfo ERROR]", err)
 		conn.Close()
 		host.Errors++
@@ -739,7 +739,7 @@ func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err e
 		return err
 	}
 	// 检查验证是否合法
-	if !h.checkAuthFunc(proto, version, channel, auth) {
+	if !h.checkAuthFunc(true, proto, version, channel, auth) {
 		err = fmt.Errorf("[checkAuthFunc ERROR] in proto: %s, version: %d, channel: %s, auth: %s", proto, version, channel, string(auth))
 		log.Println(err)
 		conn.Close()
@@ -849,7 +849,7 @@ func (h *Hub) checkConnect() {
 func NewHub(
 	cf *config.Config,
 	channel string,
-	// 客户端需要用的函数 (服务端可空)
+	// 客户端需要用的函数,提供连接的主机信息 (服务端可空)
 	connectHostFunc ConnectHostFunc,
 	// 验证函数,获取认证信息,用于发送给对方
 	authFunc AuthFunc,

+ 2 - 2
type.go

@@ -165,10 +165,10 @@ func (h *HostInfo) Key() string {
 type ConnectHostFunc func(channel string, hostType HostType) (hostInfo *HostInfo, err error)
 
 // 获取认证信息
-type AuthFunc func(proto string, version uint8, channel string, remoteAuth []byte) (auth []byte)
+type AuthFunc func(client bool, proto string, version uint8, channel string, remoteAuth []byte) (auth []byte)
 
 // 认证合法性函数
-type CheckAuthFunc func(proto string, version uint8, channel string, auth []byte) bool
+type CheckAuthFunc func(client bool, proto string, version uint8, channel string, auth []byte) bool
 
 // 验证发送的数据条件是否满足
 type CheckConnectOkFunc = func(line *Line, data *GetData) bool