GeeRPC 简析

前言

本文是记录在学习 GeeRPC 项目中的一些要点和概念,在代码部分只会给出最核心的代码,而一些外围代码,例如错误处理、一些并发细节、一些接口定义等都会被忽略。适合从一个宏观层面来理解 GeeRPC,因为 GeeRPC 原文实际上是从细节出发,难以从一个宏观视角理解,容易在「森林」中迷路。

完整的源代码还请参考仓库 7daysgolang。

什么是 RPC

RPC 全称 Remote Procedure Call。简单来说,就是将向服务区获取资源这一操作简化成一个函数调用,就像在本地调用函数一样,给入一些参数,返回一个结果。对于使用者来说,他就像本地函数调用一样简单。而本质上这是一次客户端向服务端的网络请求。

一个典型的 RPC 调用如下:

err = client.Call("Service.Method", args, &reply)

其中客户端给定的参数是服务名和方法名的字符串,并给入一组参数 args,服务端会将结果写入 reply,并返回一个错误信息 err

GeeRPC 的全局视角

从最简化的视角来看,一个 RPC 调用需要一个客户端和服务端,客户端需要发送要调用的方法名,方法的参数。服务端在收到客户端的请求之后,找到要调用的方法,并将参数传入,最终将结果返回给客户端。

从最底层的角度看,客户端需要把调用的方法名、参数序列化成一个二进制序列发送给服务端,而服务端需要解析这个二进制序列,还原成调用的方法名(字符串)、参数(特定的类型)。所以在这部分,至少需要一个编解码器,来对二进制序列进行序列化和还原。实际上这就是一种协议,来规定双方是如何进行交流通信的。

客户端收到返回值(特定类型)之后,还需要再进行一次解码,将数据还原成可读的状态。

在 Go 中,为了简化这种来回编解码的过程,和简化代码接口,将参数和返回值都作为 RPC 调用的参数。其中返回值必须是一个指针,也就是在函数内的修改是可见的,在外部才能获取到这个返回值。而为了符合 Go 的代码设计哲学,返回值代表的是错误信息。

GeeRPC 是大体上参考标准库 net/rpc 的实现,并在此基础上添加了一些额外的功能。

图片1

消息的序列化和反序列化

在网络上传输的最终是二进制数据,所以我们要对 Go 中的变量序列化成一个二进制序列,在本项目中使用了 Go 自带的序列化库 gob,其可以将 Go 中的数据类型编码成二进制序列,最终再将这个二进制序列还原成 Go 中的数据类型,一个最简单的示例如下:

str_enc := "Hello World"
var str_dec string

// 创建编解码器
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
decoder := gob.NewDecoder(&buffer)

// 编码,编码之后二进制序列写入到 buffer 中
encoder.Encode(str_enc)
// 解码,从 buffer 中读取二进制序列解码
decoder.Decode(&str_dec)

在真正编解码之前,我们需要配置一个协议来表示这是一个 RPC 请求,不能只要一个二进制序列到来就进行解析,我们定义一个 Option,其中包含两个字段 MagicNumber 用来表示开始序列,只有以这个序列开始,才表示这是一个 RPC 调用,第二个字段 CodecType 来表示编解码的方法。(也许在未来会添加新的编解码方法,不使用标准库的 gob,提供了一种拓展性)

type Option struct {
    MagicNumber int
    CodecType   string
}

随后,我们定义一个 RPC 调用是怎么构成的,我们可以将 RPC 调用拆分为 header 和 body。其中 header 代表了一些 RPC 调用的元信息,方便客户端和服务端从一开始就能获取到一些关键的信息,例如要调用的方法名、调用序号、错误信息等等。body 则代表 RPC 调用的实际数据,例如参数、返回值等等。

type Header struct {
    ServiceMethod string
    Seq           uint64
    Error         string
}

不定义 body 是因为 body 就是剩下的二进制序列,没有必要再定义一个。

我们再定义一下一个编解码器所需要的接口:

  • 他首先需要一个 Close() 函数,因为编解码器的 buffer 实际上来自于网络连接,在发生错误时,他要关闭这个网络连接。
  • 然后是 ReadHeader(*Header)ReadBody(interface{}) 接口,分别读取 header 和 body。
  • 最后实现一个 Write(*Header, interface{}) 来将二进制数据写入到网络连接中。

接着再定义编解码器的字段:

  • 一个已经建立的网络连接
  • 一个来自网络连接的缓冲区
  • 解码器
  • 编码器

所以我们可以编写以下代码:

type GobCodec struct {
    conn io.ReadWriteCloser
    buf  *bufio.Writer
    dec  *gob.Decoder
    enc  *gob.Encoder
}

// conn 是一个已经建立好的网络连接
func NewGobCodec(conn io.ReadWriteCloser) Codec {
    buf := bufio.NewWriter(conn) // 基于网络连接创建缓冲区
    return &GobCodec{
        conn: conn,
        buf:  buf,
        dec:  gob.NewDecoder(conn),
        enc:  gob.NewEncoder(buf),
    }
}

func (c *GobCodec) ReadHeader(h *Header) error {
    return c.dec.Decode(h)
}

func (c *GobCodec) ReadBody(body interface{}) error {
    return c.dec.Decode(body)
}

func (c *GobCodec) Close() error {
    return c.conn.Close()
}

func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
    defer {
	    // 调用 Flush() 确保网络数据的发送
        c.buf.Flush()
    }
    // 编码写入缓冲区,实际上会进行网络数据的发送
    c.enc.Encode(h)
    c.enc.Encode(body)
    return
}

需要注意的是,要编解码的类,类和字段只有是导出才能够被编解码。

服务

在完成消息的序列化和反序列之后,我们需要定义,什么是一个服务(service),服务端如何根据到来的信息选择所要调用的方法。当然,最简单的方法是我们可以使用硬解码,例如:

// 服务端收到信息...
if ServiceMethod == "Foo" {
	// 调用 Foo
} else if ServiceMethod == "Bar" {
	// 调用 Bar
}
// 将结果返回给客户端...

然而这样的方法是不灵活的,如果存在 100 个调用,这个 if-else 就要有 100 个分支,显然不够优雅。在这里我们就要使用到 Go 的反射机制了,反射机制简单来说就是一种动态的解析,假设代码中存在这个方法或者存在这个类,那么 Go 就能找到,不用手动硬编码,例如:

// method 是一个 string 到方法调用的映射
method = make(map[string]reflect.Method)

// 根据反射注册函数,假设 service 是某一个服务类
// 通过反射可以自动获取该类的所有函数
for i := 0; i < service.NumMethod(); i++ {
	m = service.Method(i)
	method[m.Name] = m
}

// 根据反射找到该函数
m = method[methodName]
if m != nil {
	// 如果找到,则调用
	m.call(...)
} else {
	return errors.New("can't find method " + methodName)
}

对于要调用的函数可以使用反射,对于函数的参数,同样可以使用反射,而不用硬编码,所以定义一个 methodType 来代表一个可调用的函数:

type methodType struct {
    method    reflect.Method // 函数对象
    ArgType   reflect.Type   // 参数类型
    ReplyType reflect.Type   // 返回值类型
}

并定义一个 service 来代表一个服务类:由于服务类可能是有状态的(有内部字段),所以需要一个类的实例。

type service struct {
    name   string                 // 服务名
    typ    reflect.Type           // 注册服务的类的类型
    rcvr   reflect.Value          // 注册服务的类的实例
    method map[string]*methodType // 类的所有方法
}

接着,我们就可以实现其所需要的一些函数,对于 methodType 来说,我们需要一个根据所记录的类型创建变量的函数,methodType 只记录了函数的签名,而没有实例,当我们真正需要构建一个调用时,就需要创建这些参数的实例。

// 创建一个参数类型的实例
func (m *methodType) newArgv() reflect.Value {
	var argv reflect.Value
	// 指针和值类型需要区分,有细微区别
	if m.ArgType.Kind() == reflect.Ptr {
		argv = reflect.New(m.ArgType.Elem())
	} else {
		argv = reflect.New(m.ArgType).Elem()
	}
	return argv
}

// 创建一个返回值类型的实例
func (m *methodType) newReplyv() reflect.Value {
	// 返回值必须是指针
	replyv := reflect.New(m.ReplyType.Elem())
	// 这两种特殊的数据结构需要初始化
	switch m.ReplyType.Elem().Kind() {
		case reflect.Map:
			replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
		case reflect.Slice:
			replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
	}
	return replyv
}

对于 service 来说,他需要基于一个给定的服务类实例注册服务,另一个职责则是调用指定的方法:

// 必须保证函数签名是第一个参数是 argType,第二个参数是 replyType
// 这里忽略了一些检查的代码
func (s *service) registerMethods() {
	s.method = make(map[string]*methodType)
	// 获取所有类的方法
	for i := 0; i < s.typ.NumMethod(); i++ {
		method := s.typ.Method(i)
		mType := method.Type
		argType, replyType := mType.In(1), mType.In(2)
		s.method[method.Name] = &methodType{
			method:    method,
			ArgType:   argType,
			ReplyType: replyType,
		}
	}
}

// 调用方法,需要给定调用的函数签名和参数
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
	// 获取可调用的函数对象
	f := m.method.Func
	returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
	return nil
}

这样,服务端在收到一个 RPC 调用请求之后,解析出函数名字符串,然后在 service 中的 map 中查找,找到其对应的函数前面,并创建对应的变量,将二进制数据解析写入到该变量中,随后调用 call 来进行实际的调用。

服务端

到现在,就可以进行服务端代码的编写了,服务端和客户端是两个职责非常多的类,一方面要使用编解码器进行解析,另一方面还要对可调用的方法进行维护,并考虑各种错误情况。

但是,其最核心的流程可以简化为以下几步:

  • 监听可能到来的客户端连接(TCP 连接)
  • 建立连接之后,解析并响应来自服务端的请求信息(编解码器和服务反射)

第一步最简单,Go 已经给了非常方便的 tcp 接口来实现监听,对于每个来到的 tcp 连接,我们都建立一个单独的 goroutine 来处理客户端的请求,本文就不再赘述。

第二步解析请求信息的核心步骤可以分为:

  • 查看协议头,是否是一个 RPC 请求,如果是,进入一个死循环等待客户端的请求,如果不是,断开连接
  • 读取 RPC 请求,分为读取 header 和 body
    • 根据 header 在注册的服务中查找,如果没找到,则返回服务未找到错误
    • 根据 body 创建相应的参数实例,并调用 service.call 来进行实际调用
  • 获取结果之后,发送响应给客户端

这里查看协议头的编解码器使用的是 json,和 gob 并没有什么不同,只是编码是将数据结构编码成 json 字符串,解码是根据 json 字符串复原。

在具体进入流程之前,还有一些前期工作要做,第一个是类的定义,第二个是服务的注册:

// 服务端只有一个字段,代表服务名到服务实例的注册
type Server struct {
	serviceMap Map[string]*service
}

func (server *Server) Register(rcvr interface{}) error {
	s := newService(rcvr)
	server.serviceMap.LoadOrStore(s.name, s)
	return nil
}

// 创建一个服务
func newService(rcvr interface{}) *service {
	s := new(service)
	// 根据 rcvr 创建一个相同类型的值
	s.rcvr = reflect.ValueOf(rcvr)
	// Indirect 会不断解引用,找到最终的那个值
	// 假设 rcvr 是多层嵌套的指针,也会找到最终所指向的值
	s.name = reflect.Indirect(s.rcvr).Type().Name()
	// 获取 rcvr 的类型
	s.typ = reflect.TypeOf(rcvr)
	s.registerMethods()
	return s
}

现在我们就可以到具体的服务器流程中了

func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	defer
		conn.Close()
	var opt Option
	// 查看协议头
	json.NewDecoder(conn).Decode(&opt)
	// 协议头是否满足
	if opt.MagicNumber != MagicNumber {
		return
	}
	// 创建编解码器
	cc := codec.NewGobCodec(conn)
	server.serveCodec(cc)
}

func (server *Server) serveCodec(cc codec.Codec) {
	for {
		req, _ := server.readRequest(cc)
		// 省略请求错误的处理...
		go server.handleRequest(cc, req)
	}
	cc.Close()
}

在完成连接的建立和初步的协议检查之后,就进入了一个死循环,分为读取请求和处理请求,我们先看读取请求:

func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	h, _ := server.readRequestHeader(cc)
	// 创建 request 实例,并给各个字段赋值
	req := &request{h: h}
	// 在注册的服务中查找,获取服务类实例和函数签名
	req.svc, req.mtype, _ = server.findService(h.ServiceMethod)
	req.argv = req.mtype.newArgv()
	req.replyv = req.mtype.newReplyv()

	// Interface() 的作用是将反射值拷贝并转化类型为 interface{}
	argvi := req.argv.Interface()
	// 转化为指针之后才能将数据写入到 argvi 中
	if req.argv.Type().Kind() != reflect.Ptr {
		argvi = req.argv.Addr().Interface()
	}

	// 写入数据
	cc.ReadBody(argvi)
	return req, nil
}

func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
	var h codec.Header
	cc.ReadHeader(&h)
	return &h, nil
}

func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
	dot := strings.LastIndex(serviceMethod, ".")
	serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
	svci, _ := server.serviceMap.Load(serviceName)
	svc = svci.(*service)
	mtype = svc.method[methodName]
	return svc, mtype, err
}

在读取完请求之后,所有的参数都会在本地创建对应的实例,然后再进行处理请求:

func (server *Server) handleRequest(cc codec.Codec, req *request) {
	req.svc.call(req.mtype, req.argv, req.replyv)
	server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}

func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}) {
	cc.Write(h, body)
}

服务端的主要功能就是这一些,在所给的代码中,忽略了错误检查等等外围代码。

客户端

对于客户端来说,其职责就简单地多,只需要将参数打包,编码成所需的格式,然后再发送到网络即可,大体可以分为以下几步:

  • 设置协议头并建立连接
  • 将参数打包,指定服务和方法名,编码之后发送请求
  • 等待响应,在接收响应之后,解码到来的数据,并在本地创建相应的实例

为了支持并发,geerpc 中使用了一个协程不断等待请求的到来,另一个协程来发送请求,但是简单理解还是可以看作是串行的。

我们首先来看客户端类如何设计,需要什么字段:

// 一个 Call 代表一次调用
type Call struct {
	ServiceMethod string      // format "<service>.<method>"
	Args          interface{} // arguments to the function
	Reply         interface{} // reply from the function
	Error         error       // if error occurs, it will be set
}

type Client struct {
	cc       codec.Codec
	opt      *Option
	header   codec.Header
	closing  bool             // user has called Close
	shutdown bool             // server has told us to stop
}

这里简化了一些字段,是为了并发实现的,这里忽略。首先建立连接和确认协议。

func NewClient(conn net.Conn, opt *Option) (*Client, error) {
	cc := codec.NewGobCodec(conn)
	// 将协议头发送给服务器
	// 如果协议不对,连接会被服务器中断
	json.NewEncoder(conn).Encode(opt)
	client := &Client{
		cc:      cc,
		opt:     opt,
	}
	return client, nil
}

随后就要发送请求,这里也删去了并发的一些内容,所以可能看起来函数有点奇怪:

func (client *Client) send(call *Call) {
	// 创建请求头
	client.header.ServiceMethod = call.ServiceMethod
	client.header.Seq = seq
	client.header.Error = ""

	// 编码并发送数据
	client.cc.Write(&client.header, call.Args)
}

最后就是接收响应了,实际上是有另一个死循环不断执行 receive 函数,本文只需要理解核心代码即可。

// 将数据写入到 Call 里
func (client *Client) receive(call *Call) {
	var h codec.Header
	client.cc.ReadHeader(&h)
	client.cc.ReadBody(call.Reply)
}

// 主要流程
type Args struct{ Num1, Num2 int }
fun main() {
	conn, _ := net.Dial("tcp", "127.0.0.1:1111")
	var opt = &Option{
		MagicNumber: MagicNumber,
		CodecType:   codec.GobType,
	}
	args := &Args{Num1: i, Num2: i * i}
	call := &Call{
		ServiceMethod: "Service.Hello",
		Args:          args,
		Reply:         &reply,
	}
	client = NewClient(conn, &opt)
	client.send(&call)
	client.receive(&call)

	reply := call.Reply.(int)
}

到此,geerpc 的主要框架就梳理好了!

外围服务

除了主要框架的基本功能之外,geerpc 还实现了一些其他功能,例如支持 HTTP,负载均衡

HTTP

支持 HTTP 的出发点在于能够让 RPC 更加灵活、可扩展,并易于与现有的 Web 基础设施集成。其原理就是服务端和客户端的第一次信息交换是通过 HTTP 协议进行的,而一旦双方确认之后,后续的信息传输是在一个底层的 TCP 中实现的。

在客户端的修改比较简单,只需要单独增加一个 HTTP 协议的客户端:

func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
	_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))

	// 确认 HTTP 状态
	// resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	// ...

	return NewClient(conn, opt)
}

此时我们手动构造了一个字符串发送给服务端,服务端会视为是一个 HTTP 请求,并响应一个 HTTP 协议的回复,后续就和原来一样了。

这个关键在于,第一次的 HTTP 交流是为了建立 TCP 连接,在建立了 TCP 连接之后,HTTP 也就不需要了。

// 外围函数:在 Go 中注册 HTTP 响应函数
func (server *Server) HandleHTTP() {
	http.Handle(defaultRPCPath, server)
}

// 核心函数,在 HTTP 请求到来之后会调用
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}
	// 获取底层 TCP 连接
	// 执行 Hijack() 之后就不能再用 http.ResponseWriter 响应请求了
	conn, _, _ := w.(http.Hijacker).Hijack()
	_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")

	// 后续就和原来一模一样了
	server.ServeConn(conn)
}

负载均衡

负载均衡的本质就是有多个一样的服务器可以请求,多个客户端尽可能平均地将请求分布在这多个服务器上。geerpc 的负载均衡主要在客户端上实现,服务端并没有做什么修改。

但负载均衡的实现,一方面是可以通过客户端解决,例如在客户端上储存多个服务器的地址,采用一些算法均匀地将请求分布在这些服务器上面。另一方面是可以在服务端设定一些机制,让客户端去请求另一台服务器。

还有一种更方便的方法,建立一个中心节点(代理),负责分发这些请求,所有的客户端请求都只会给这个代理服务器,这个代理服务器来决定最终这个请求要转发给哪一台服务器,这样的好处是易于管理。

在 geerpc 中就只是在客户端做了修改,如果客户端是恶意的,则负载均衡就没有意义了。

func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
	// 根据负载均衡模式,获取下一个服务器的地址
	// 最终的调用就会发送给这个地址的服务器
	rpcAddr, _ := xc.d.Get(xc.mode)
	return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}

// 最简单的有随机选择和轮转选择
func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) {
	n := len(d.servers)
	switch mode {
		case RandomSelect:
			return d.servers[d.r.Intn(n)], nil
		case RoundRobinSelect:
			s := d.servers[d.index%n]
			d.index = (d.index + 1) % n
			return s, nil
	}
}

目前来说,客户端需要主动维护所有的服务器列表,但是这样的信息对于客户端来说是难以获得的。更好的方式是客户端只需要知道代理节点,代理节点来维护所有的服务器列表,并且可以移除那些已经下线的服务器。维护服务器列表的核心在于心跳连接,代理节点会定期向所有服务器发送心跳信息,假如什么时候收不到心跳信息,则代表该服务器下线,移除这个服务器地址,不会再发送给客户端。

这部分实际上就是一些额外的业务代码,理解原理之后,就是一些代码的编写,本文就不再赘述了。

下一篇

评论 | 0条评论