Joyit преди 1 месец
родител
ревизия
07f18630e4
променени са 3 файла, в които са добавени 21 реда и са изтрити 10 реда
  1. 0 1
      conn/conn.go
  2. 9 4
      hub.go
  3. 12 5
      line.go

+ 0 - 1
conn/conn.go

@@ -31,7 +31,6 @@ func (t MsgType) String() string {
 type Connect interface {
 	WriteAuthInfo(channel string, auth []byte) (err error)
 	ReadAuthInfo() (proto string, version uint8, channel string, auth []byte, err error)
-	// WriteChannel(data []byte) error
 	WriteRequest(id uint16, cmd string, data []byte) error
 	WriteResponse(id uint16, state uint8, data []byte) error
 	WritePing(id uint16) error

+ 9 - 4
hub.go

@@ -205,7 +205,7 @@ func (h *Hub) ChannelToFunc(fn func(string) bool) {
 // 从 channel 获取连接
 func (h *Hub) ChannelToLine(channel string) (line *Line) {
 	h.lines.Range(func(id int, l *Line) bool {
-		if l.channel == channel {
+		if l.IsChannelEqual(channel) {
 			line = l
 			return false
 		}
@@ -451,7 +451,6 @@ func (h *Hub) GetWithStruct(gd *GetData, backFunc GetBackFunc) (count int) {
 			return
 		}
 	}
-	// return
 }
 
 // 请求频道并获取数据,采用回调的方式返回结果
@@ -602,7 +601,7 @@ func (h *Hub) BindForServer(info *HostInfo) (err error) {
 		// 将连接加入现有连接中
 		done := false
 		h.lines.Range(func(id int, line *Line) bool {
-			if line.state == Disconnected && line.host == nil && line.ChannelEqualWithoutPrefix(channel) {
+			if line.state == Disconnected && line.host == nil && line.IsChannelEqual(channel) {
 				line.Start(conn, nil)
 				done = true
 				return false
@@ -698,8 +697,14 @@ func (h *Hub) ConnectToServer(channel string, force bool, host *HostInfo) (err e
 		done = true
 		return errors.New("timeout")
 	}
+
+	// 如果 host 是代理,将代理信息添加到channel中
+	localChannel := h.channel
+	if host.Proxy {
+		localChannel = localChannel + "?proxy=" + host.Host
+	}
 	// 发送验证信息
-	if err := conn.WriteAuthInfo(h.channel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
+	if err := conn.WriteAuthInfo(localChannel, h.authFunc(runProto, host.Version, channel, nil)); err != nil {
 		log.Println("[WriteAuthInfo ERROR]", err)
 		conn.Close()
 		host.Errors++

+ 12 - 5
line.go

@@ -84,15 +84,22 @@ func (c *Line) RemoveChannelName() {
 	}
 }
 
-// 频道是否相等,不包括@前面部分
-func (c *Line) ChannelEqualWithoutPrefix(channel string) bool {
+// 频道是否相等,不包括@前面部分和参数部分
+func (c *Line) IsChannelEqual(channel string) bool {
 	if inx := strings.Index(channel, "@"); inx >= 0 {
 		channel = channel[inx+1:]
 	}
-	if inx := strings.Index(c.channel, "@"); inx >= 0 {
-		return channel == c.channel[inx+1:]
+	if inx := strings.Index(channel, "?"); inx > 0 {
+		channel = channel[0:inx]
+	}
+	origin := c.channel
+	if inx := strings.Index(origin, "@"); inx >= 0 {
+		origin = origin[inx+1:]
+	}
+	if inx := strings.Index(origin, "?"); inx > 0 {
+		origin = origin[0:inx]
 	}
-	return channel == c.channel
+	return channel == origin
 }
 
 // 获取远程的地址