Golang实现自己的RPC框架
发布日期:2021-06-30 20:40:41 浏览次数:2 分类:技术文章

本文共 6185 字,大约阅读时间需要 20 分钟。

rpc/session.go

package rpcimport (	"encoding/binary"	"io"	"net")// 编写数据会话中读写// 会话连接的结构体type Session struct {	conn net.Conn}// 创建新连接func NewSession(conn net.Conn) *Session {	return &Session{conn: conn}}// 向连接中写数据func (s Session) Write(data []byte) error {	// 4字节头+数据长度切片	buf := make([]byte, 4+len(data))	// 写入头部数据,记录数据长度	// binary 只认固定长度的类型,所以使用了uint32,而不是直接写入	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))	copy(buf[:4], data)	_, err := s.conn.Write(buf)	if err != nil {		return err	}	return nil}// 从连接中读数据func (s Session) Read() ([]byte, error) {	// 读取头部长度	header := make([]byte, 4)	// 按头部长度, 读取头部数据	_, err := io.ReadFull(s.conn, header)	if err != nil {		return nil, err	}	// 读取数据长度	dataLen := binary.BigEndian.Uint32(header)	// 按照数据长度去读取数据	data := make([]byte, dataLen)	_, err = io.ReadFull(s.conn, data)	if err != nil {		return nil, err	}	return data, nil}

rpc/session_test.go

package rpcimport (	"fmt"	"net"	"sync"	"testing")func TestSession_ReadWrite(t *testing.T) {	// 定义监听IP和端口	addr := "127.0.0.1:8080"	// 定义传输的数据	my_data := "hello world"	// 等待组	wg := sync.WaitGroup{}	// 协程,一个读,一个写	wg.Add(2)	// 写数据协程	go func() {		defer wg.Done()		// 创建tcp连接		lis, err := net.Listen("tcp", addr)		if err != nil {			t.Fatal(err)		}		conn,_ := lis.Accept()		s := Session{conn: conn}		// 写数据		err = s.Write([]byte(my_data))		if err != nil {			t.Fatal(err)		}	}()	// 读数据协程	go func() {		defer wg.Done()		conn, err := net.Dial("tcp", addr)		if err != nil {			t.Fatal(err)		}		s := Session{conn: conn}		// 读数据		data, err := s.Read()		if err != nil {			t.Fatal(err)		}		if string(data) != my_data {			t.Fatal(err)		}		fmt.Println(string(data))	}()	wg.Wait()}

 

rpc/codec.go

package rpcimport (	"bytes"	"encoding/gob")// 定义数据格式和编解码type RPCData struct {	// 访问的函数	Name string	// 访问时传的参数	Args []interface{}}// 编码func encode(data RPCData) ([]byte, error) {	var buf bytes.Buffer	// 得到字节数组的编码器	bufEnc := gob.NewEncoder(&buf)	// 对数据进行编码	bufEnc.Encode(data)	if err := bufEnc.Encode(data); err != nil {		return nil, err	}	return buf.Bytes(), nil}// 解码func decode(b []byte) (RPCData, error) {	buf := bytes.NewBuffer(b)	// 返回字节数组的解码器	bufDec := gob.NewDecoder(buf)	var data RPCData	// 对数据解码	if err := bufDec.Decode(&data); err != nil {		return data, nil	}	return data, nil}

rpc/server.go

package rpcimport (	"fmt"	"net"	"reflect")// 声明服务端type Server struct {	// 地址	addr string	// 服务端维护的函数名到函数反射值的map	funcs map[string]reflect.Value}// 创建服务端对象func NewServer(addr string) *Server {	return &Server{addr: addr, funcs:make(map[string]reflect.Value)}}// 服务端绑定注册方法// 将函数名与函数真正实现对应起来// 第一个参数为函数名, 第二个传入真正的函数func (s *Server) Register(rpcName string, f interface{}) {	if _, ok := s.funcs[rpcName]; ok {		return	}	// map中没有值,则将映射添加进map,便于调用	fVal := reflect.ValueOf(f)	s.funcs[rpcName] = fVal}// 服务端等待调用func (s *Server) Run() {	// 监听	lis, err := net.Listen("tcp", s.addr)	if err != nil {		fmt.Printf("监听%s err:%v", s.addr, err)		return	}	for {		// 拿到连接		conn, err := lis.Accept()		if err != nil {			fmt.Printf("accept err:%v", err)			return		}		// 创建会话		srvSession := NewSession(conn)		// RPC 读取数据		b, err := srvSession.Read()		if err != nil {			fmt.Printf("read err:%v", err)			return		}		// 对数据解码		rpcData, err := decode(b)		if err != nil {			fmt.Printf("decode err:%v", err)			return		}		// 根据读取到的数据的Name,得到调用的函数名		f, ok := s.funcs[rpcData.Name]		if !ok {			fmt.Printf("函数名%s不存在", rpcData.Name)		}		// 解析遍历客户端出来的参数, 放到一个数组中		inArgs := make([]reflect.Value, 0, len(rpcData.Args))		for _, arg := range rpcData.Args {			inArgs = append(inArgs, reflect.ValueOf(arg))		}		// 反射调用方法,传入参数		out := f.Call(inArgs)		// 解析遍历执行结果,放到一个数组中		outArgs := make([]interface{}, 0, len(out))		for _, o := range out {			outArgs = append(outArgs, o.Interface())		}		// 包装数据返回给客户端		respRPCData := RPCData{rpcData.Name, outArgs}		// 编码		respBytes, err := encode(respRPCData)		if err != nil {			fmt.Printf("encode err: %v", err)			return		}		// 使用RPC写出数据		err = srvSession.Write(respBytes)		if err != nil {			fmt.Printf("session write err:%v", err)			return		}	}}

rpc/client.go

package rpcimport (	"net"	"reflect")// 声明客户端type Client struct {	conn net.Conn}// 创建客户端对象func NewClient(conn net.Conn) *Client {	return &Client{conn: conn}}// 实现通用的RPC客户端// 绑定RPC使用的方法// 传入访问的函数名// 函数具体实现在Server端, Client只有函数原型// 使用MakeFunc() 完成原型到函数的调用// fPtr指向函数原型func (c *Client) callRPC(rpcName string, fPtr interface{}) {	// 通过反射,获取fPtr未初始化的函数原型	fn := reflect.ValueOf(fPtr).Elem()	// 另一个函数,是对第一个函数参数操作	f := func(args []reflect.Value) []reflect.Value {		// 处理输入的参数		inArgs := make([]interface{}, 0, len(args))		for _, arg := range args{			inArgs = append(inArgs, arg.Interface())		}		// 创建连接		cliSession := NewSession(c.conn)		// 编码数据		reqRPC := RPCData{Name: rpcName, Args: inArgs}		b, err := encode(reqRPC)		if err != nil {			panic(nil)		}		// 写出数据		err = cliSession.Write(b)		if err != nil {			panic(nil)		}		// 读响应数据		respBytes, err := cliSession.Read()		if err != nil {			panic(err)		}		// 解码数据		respRPC, err := decode(respBytes)		if err != nil {			panic(err)		}		// 处理服务端返回的数据		outArgs := make([]reflect.Value, 0, len(respRPC.Args))		for i, arg := range respRPC.Args {			// 必须进行nil转换			if arg == nil {				// 必须填充一个真正的类型,不能是nil				outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))				continue			}		}		return outArgs	}	v := reflect.MakeFunc(fn.Type(), f)	// 为函数fPtr赋值	fn.Set(v)}

rpc/simple_tpc_test.go

package rpcimport (	"encoding/gob"	"fmt"	"net"	"testing")// 用户查询// 用于测试的结构体type User struct {	Name string	Age int}// 用于测试查询用户的方法func queryUser(uid int) (User, error) {	user := make(map[int]User)	user[0] = User{"zs", 20}	user[1] = User{"ls", 21}	user[2] = User{"ww", 22}	// 模拟查询用户	if u, ok := user[uid]; ok {		return u, nil	}	return User{}, fmt.Errorf("id %d not in user db", uid)}func TestRPC(t *testing.T) {	// 需要对interface可能产生的类型进行注册	gob.Register(User{})	addr := "127.0.0.1:8080"	// 创建服务端	srv := NewServer(addr)	// 将方法注册到服务端	srv.Register("queryUser", queryUser)	// 服务端等待调用	go srv.Run()	// 客户端获取连接	conn , err := net.Dial("tcp", addr)	if err != nil {		t.Error(err)	}	// 创建客户端	cli := NewClient(conn)	// 声明函数原型	var query func(int) (User error)	cli.callRPC("queryUser", &query)	// 得到查询结果	u, err := query(1)	if err != nil {		t.Fatal(err)	}	fmt.Println(u)}

 

转载地址:https://liushilong.blog.csdn.net/article/details/114434266 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:【PHP版】顺丰下单API 、查询订单API、取消订单API
下一篇:Golang 编写RPC

发表评论

最新留言

逛到本站,mark一下
[***.202.152.39]2024年04月17日 13时13分20秒