为了账号安全,请及时绑定邮箱和手机立即绑定

手把手教你基于Netty实现一个基础的RPC框架(通俗易懂)

标签:
Java

在前面的内容中,我们已经由浅入深的理解了Netty的基础知识和实现原理,相信大家已经对Netty有了一个较为全面的理解。那么接下来,我们通过一个手写RPC通信的实战案例来带大家了解Netty的实际应用。

为什么要选择RPC来作为实战呢?因为Netty本身就是解决通信问题,而在实际应用中,RPC协议框架是我们接触得最多的一种,所以这个实战能让大家了解到Netty实际应用之外,还能理解RPC的底层原理。

什么是RPC

RPC全称为(Remote Procedure Call),是一种通过网络从远程计算机程序上请求服务,而不需要了解底层网络技术的协议,简单理解就是让开发者能够像调用本地服务一样调用远程服务。

既然是协议,那么它必然有协议的规范,如图6-1所示。

为了达到“让开发者能够像调用本地服务那样调用远程服务”的目的,RPC协议需像图6-1那样实现远程交互。

  • 客户端调用远程服务时,必须要通过本地动态代理模块来屏蔽网络通信的细节,所以动态代理模块需要负责将请求参数、方法等数据组装成数据包发送到目标服务器
  • 这个数据包在发送时,还需要遵循约定的消息协议以及序列化协议,最终转化为二进制数据流传输
  • 服务端收到数据包后,先按照约定的消息协议解码,得到请求信息。
  • 服务端再根据请求信息路由调用到目标服务,获得结果并返回给客户端。

1567677351249

图6-1

业内主流的RPC框架

凡是满足RPC协议的框架,我们成为RPC框架,在实际开发中,我们可以使用开源且相对成熟的RPC框架解决微服务架构下的远程通信问题,常见的rpc框架:

  1. Thrift:thrift是一个软件框架,用来进行可扩展且跨语言的服务的开发。它结合了功能强大的软件堆栈和代码生成引擎,以构建在 C++, Java, Python, PHP, Ruby, Erlang, Perl, Haskell, C#, Cocoa, JavaScript, Node.js, Smalltalk, and OCaml 这些编程语言间无缝结合的、高效的服务。
  2. Dubbo:Dubbo是一个分布式服务框架,以及SOA治理方案。其功能主要包括:高性能NIO通讯及多协议集成,服务动态寻址与路由,软负载均衡与容错,依赖分析与降级等。 Dubbo是阿里巴巴内部的SOA服务化治理方案的核心框架,Dubbo自2011年开源后,已被许多非阿里系公司使用。

手写RPC注意要点

基于上文中对于RPC协议的理解,如果我们自己去实现,需要考虑哪些技术呢? 其实基于图6-1的整个流程应该有一个大概的理解。

  • 通信协议,RPC框架对性能的要求非常高,所以通信协议应该是越简单越好,这样可以减少编解码带来的性能损耗,大部分主流的RPC框架会直接选择TCP、HTTP协议。
  • 序列化和反序列化,数据要进行网络传输,需要对数据进行序列化和反序列化,前面我们说过,所谓的序列化和反序列化是不把对象转化成二进制流以及将二进制流转化成对象的过程。在序列化框架选择上,我们一般会选择高效且通用的算法,比如FastJson、Protobuf、Hessian等。这些序列化技术都要比原生的序列化操作更加高效,压缩比也较高。
  • 动态代理, 客户端调用远程服务时,需要通过动态代理来屏蔽网络通信细节。而动态代理又是在运行过程中生成的,所以动态代理类的生成速度、字节码大小都会影响到RPC整体框架的性能和资源消耗。常见的动态代理技术: Javassist、Cglib、JDK的动态代理等。

基于Netty手写实现RPC

理解了RPC协议后,我们基于Netty来实现一个RPC通信框架。

代码详见附件 netty-rpc-example

image-20210907221358022

图6-2 项目模块组成

需要引入的jar包:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.72</version>
</dependency>
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
</dependency>

模块依赖关系:

  • provider依赖 netty-rpc-protocol和netty-rpc-api

  • cosumer依赖 netty-rpc-protocol和netty-rpc-api

netty-rpc-api模块

image-20210907223045613

图6-3 netty-rpc-api模块组成

IUserService

public interface IUserService {

    String saveUser(String name);
}

netty-rpc-provider模块

image-20210907223111784

图6-4 netty-rpc-provider模块组成

UserServiceImpl

@Service
@Slf4j
public class UserServiceImpl implements IUserService {
    @Override
    public String saveUser(String name) {
        log.info("begin saveUser:"+name);
        return "Save User Success!";
    }
}

NettyRpcProviderMain

注意,在当前步骤中,描述了case的部分,暂时先不用加,后续再加上

@ComponentScan(basePackages = {"com.example.spring","com.example.service"})  //case1(后续再加上)
@SpringBootApplication
public class NettyRpcProviderMain {

    public static void main(String[] args) throws Exception {
        SpringApplication.run(NettyRpcProviderMain.class, args);
        new NettyServer("127.0.0.1",8080).startNettyServer();   //case2(后续再加上)
    }
}

netty-rpc-protocol

开始写通信协议模块,这个模块主要做几个事情

  • 定义消息协议
  • 定义序列化反序列化方法
  • 建立netty通信

图6-5

定义消息协议

之前我们讲过自定义消息协议,我们在这里可以按照下面这个协议格式来定义好。

    /*
    +----------------------------------------------+
    | 魔数 2byte | 序列化算法 1byte | 请求类型 1byte  |
    +----------------------------------------------+
    | 消息 ID 8byte     |      数据长度 4byte       |
    +----------------------------------------------+
    */

Header

@AllArgsConstructor
@Data
public class Header implements Serializable {
    /*
    +----------------------------------------------+
    | 魔数 2byte | 序列化算法 1byte | 请求类型 1byte  |
    +----------------------------------------------+
    | 消息 ID 8byte     |      数据长度 4byte       |
    +----------------------------------------------+
    */
    private short magic; //魔数-用来验证报文的身份(2个字节)
    private byte serialType; //序列化类型(1个字节)
    private byte reqType; //操作类型(1个字节)
    private long requestId; //请求id(8个字节)
    private int length; //数据长度(4个字节)

}

RpcRequest

@Data
public class RpcRequest implements Serializable {
    private String className;
    private String methodName;
    private Object[] params;
    private Class<?>[] parameterTypes;
}

RpcResponse

@Data
public class RpcResponse implements Serializable {

    private Object data;
    private String msg;
}

RpcProtocol

@Data
public class RpcProtocol<T> implements Serializable {
    private Header header;
    private T content;
}

定义相关常量

上述消息协议定义中,涉及到几个枚举相关的类,定义如下

ReqType

消息类型

public enum ReqType {

    REQUEST((byte)1),
    RESPONSE((byte)2),
    HEARTBEAT((byte)3);

    private byte code;

    private ReqType(byte code) {
        this.code=code;
    }

    public byte code(){
        return this.code;
    }
    public static ReqType findByCode(int code) {
        for (ReqType msgType : ReqType.values()) {
            if (msgType.code() == code) {
                return msgType;
            }
        }
        return null;
    }
}

SerialType

序列化类型

public enum SerialType {

    JSON_SERIAL((byte)0),
    JAVA_SERIAL((byte)1);

    private byte code;

    SerialType(byte code) {
        this.code=code;
    }

    public byte code(){
        return this.code;
    }
}

RpcConstant

public class RpcConstant {
    //header部分的总字节数
    public final static int HEAD_TOTAL_LEN=16;
    //魔数
    public final static short MAGIC=0xca;
}

定义序列化相关实现

这里演示两种,一种是JSON方式,另一种是Java原生的方式

ISerializer

public interface ISerializer {

    <T> byte[] serialize(T obj);

    <T> T deserialize(byte[] data,Class<T> clazz);

    byte getType();
}

JavaSerializer

public class JavaSerializer implements ISerializer{

    @Override
    public <T> byte[] serialize(T obj) {
        ByteArrayOutputStream byteArrayOutputStream=
                new ByteArrayOutputStream();
        try {
            ObjectOutputStream outputStream=
                    new ObjectOutputStream(byteArrayOutputStream);

            outputStream.writeObject(obj);

            return  byteArrayOutputStream.toByteArray();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new byte[0];
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> clazz) {
        ByteArrayInputStream byteArrayInputStream=new ByteArrayInputStream(data);
        try {
            ObjectInputStream objectInputStream=
                    new ObjectInputStream(byteArrayInputStream);

            return (T) objectInputStream.readObject();

        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    @Override
    public byte getType() {
        return SerialType.JAVA_SERIAL.code();
    }
}

JsonSerializer

public class JsonSerializer implements ISerializer{
    @Override
    public <T> byte[] serialize(T obj) {
        return JSON.toJSONString(obj).getBytes();
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> clazz) {
        return JSON.parseObject(new String(data),clazz);
    }

    @Override
    public byte getType() {
        return SerialType.JSON_SERIAL.code();
    }
}

SerializerManager

实现对序列化机制的管理

public class SerializerManager {

    private final static ConcurrentHashMap<Byte, ISerializer> serializers=new ConcurrentHashMap<Byte, ISerializer>();

    static {
        ISerializer jsonSerializer=new JsonSerializer();
        ISerializer javaSerializer=new JavaSerializer();
        serializers.put(jsonSerializer.getType(),jsonSerializer);
        serializers.put(javaSerializer.getType(),javaSerializer);
    }

    public static ISerializer getSerializer(byte key){
        ISerializer serializer=serializers.get(key);
        if(serializer==null){
            return new JavaSerializer();
        }
        return serializer;
    }
}

定义编码和解码实现

由于自定义了消息协议,所以 需要自己实现编码和解码,代码如下

RpcDecoder

@Slf4j
public class RpcDecoder extends ByteToMessageDecoder {


    /*
    +----------------------------------------------+
    | 魔数 2byte | 序列化算法 1byte | 请求类型 1byte  |
    +----------------------------------------------+
    | 消息 ID 8byte     |      数据长度 4byte       |
    +----------------------------------------------+
    */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        log.info("==========begin RpcDecoder ==============");
        if(in.readableBytes()< RpcConstant.HEAD_TOTAL_LEN){
            //消息长度不够,不需要解析
            return;
        }
        in.markReaderIndex();//标记一个读取数据的索引,后续用来重置。
        short magic=in.readShort(); //读取magic
        if(magic!=RpcConstant.MAGIC){
            throw new IllegalArgumentException("Illegal request parameter 'magic',"+magic);
        }
        byte serialType=in.readByte(); //读取序列化算法类型
        byte reqType=in.readByte(); //请求类型
        long requestId=in.readLong(); //请求消息id
        int dataLength=in.readInt(); //请求数据长度
        //可读区域的字节数小于实际数据长度
        if(in.readableBytes()<dataLength){
            in.resetReaderIndex();
            return;
        }
        //读取消息内容
        byte[] content=new byte[dataLength];
        in.readBytes(content);

        //构建header头信息
        Header header=new Header(magic,serialType,reqType,requestId,dataLength);
        ISerializer serializer=SerializerManager.getSerializer(serialType);
        ReqType rt=ReqType.findByCode(reqType);
        switch(rt){
            case REQUEST:
                RpcRequest request=serializer.deserialize(content, RpcRequest.class);
                RpcProtocol<RpcRequest> reqProtocol=new RpcProtocol<>();
                reqProtocol.setHeader(header);
                reqProtocol.setContent(request);
                out.add(reqProtocol);
                break;
            case RESPONSE:
                RpcResponse response=serializer.deserialize(content,RpcResponse.class);
                RpcProtocol<RpcResponse> resProtocol=new RpcProtocol<>();
                resProtocol.setHeader(header);
                resProtocol.setContent(response);
                out.add(resProtocol);
                break;
            case HEARTBEAT:
                break;
            default:
                break;
        }

    }
}

RpcEncoder

@Slf4j
public class RpcEncoder extends MessageToByteEncoder<RpcProtocol<Object>> {

    /*
    +----------------------------------------------+
    | 魔数 2byte | 序列化算法 1byte | 请求类型 1byte  |
    +----------------------------------------------+
    | 消息 ID 8byte     |      数据长度 4byte       |
    +----------------------------------------------+
    */
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcProtocol<Object> msg, ByteBuf out) throws Exception {
        log.info("=============begin RpcEncoder============");
        Header header=msg.getHeader();
        out.writeShort(header.getMagic()); //写入魔数
        out.writeByte(header.getSerialType()); //写入序列化类型
        out.writeByte(header.getReqType());//写入请求类型
        out.writeLong(header.getRequestId()); //写入请求id
        ISerializer serializer= SerializerManager.getSerializer(header.getSerialType());
        byte[] data=serializer.serialize(msg.getContent()); //序列化
        header.setLength(data.length);
        out.writeInt(data.length); //写入消息长度
        out.writeBytes(data);
    }
}

NettyServer

实现NettyServer构建。

@Slf4j
public class NettyServer{
    private String serverAddress; //地址
    private int serverPort; //端口

    public NettyServer(String serverAddress, int serverPort) {
        this.serverAddress = serverAddress;
        this.serverPort = serverPort;
    }

    public void startNettyServer() throws Exception {
        log.info("begin start Netty Server");
        EventLoopGroup bossGroup=new NioEventLoopGroup();
        EventLoopGroup workGroup=new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workGroup)
                .channel(NioServerSocketChannel.class)
                .childHandler(new RpcServerInitializer());
            ChannelFuture channelFuture = bootstrap.bind(this.serverAddress, this.serverPort).sync();
            log.info("Server started Success on Port:{}", this.serverPort);
            channelFuture.channel().closeFuture().sync();
        }catch (Exception e){
            log.error("Rpc Server Exception",e);
        }finally {
            workGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }
}

RpcServerInitializer

public class RpcServerInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
            .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE,12,4,0,0))
            .addLast(new RpcDecoder())
            .addLast(new RpcEncoder())
            .addLast(new RpcServerHandler());
    }
}

RpcServerHandler

public class RpcServerHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcRequest>> {

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> msg) throws Exception {
        RpcProtocol resProtocol=new RpcProtocol<>();
        Header header=msg.getHeader();
        header.setReqType(ReqType.RESPONSE.code());
        Object result=invoke(msg.getContent());
        resProtocol.setHeader(header);
        RpcResponse response=new RpcResponse();
        response.setData(result);
        response.setMsg("success");
        resProtocol.setContent(response);

        ctx.writeAndFlush(resProtocol);
    }

    private Object invoke(RpcRequest request){
        try {
            Class<?> clazz=Class.forName(request.getClassName());
            Object bean= SpringBeansManager.getBean(clazz); //获取实例对象(CASE)
            Method declaredMethod=clazz.getDeclaredMethod(request.getMethodName(),request.getParameterTypes());
            return declaredMethod.invoke(bean,request.getParams());
        } catch (ClassNotFoundException | NoSuchMethodException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
        return null;
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
    }
}

SpringBeansManager

@Component
public class SpringBeansManager implements ApplicationContextAware {
    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringBeansManager.applicationContext=applicationContext;
    }

    public static <T> T getBean(Class<T> clazz){
        return applicationContext.getBean(clazz);
    }
}

需要注意,这个类的构建好之后,需要在netty-rpc-provider模块的main方法中增加compone-scan进行扫描

@ComponentScan(basePackages = {"com.example.spring","com.example.service"})  //修改这里
@SpringBootApplication
public class NettyRpcProviderMain {

    public static void main(String[] args) throws Exception {
        SpringApplication.run(NettyRpcProviderMain.class, args);
        new NettyServer("127.0.0.1",8080).startNettyServer();  // 修改这里
    }
}

netty-rpc-consumer

接下来开始实现消费端

RpcClientProxy

public class RpcClientProxy {
    
    public <T> T clientProxy(final Class<T> interfaceCls,final String host,final int port){
        return (T) Proxy.newProxyInstance
                (interfaceCls.getClassLoader(),
                        new Class<?>[]{interfaceCls},
                        new RpcInvokerProxy(host,port));
    }
}

RpcInvokerProxy

@Slf4j
public class RpcInvokerProxy implements InvocationHandler {

    private String serviceAddress;
    private int servicePort;

    public RpcInvokerProxy(String serviceAddress, int servicePort) {
        this.serviceAddress = serviceAddress;
        this.servicePort = servicePort;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        log.info("begin invoke target server");
        //组装参数
        RpcProtocol<RpcRequest> protocol=new RpcProtocol<>();
        long requestId= RequestHolder.REQUEST_ID.incrementAndGet();
        Header header=new Header(RpcConstant.MAGIC, SerialType.JSON_SERIAL.code(), ReqType.REQUEST.code(),requestId,0);
        protocol.setHeader(header);
        RpcRequest request=new RpcRequest();
        request.setClassName(method.getDeclaringClass().getName());
        request.setMethodName(method.getName());
        request.setParameterTypes(method.getParameterTypes());
        request.setParams(args);
        protocol.setContent(request);
        //发送请求
        NettyClient nettyClient=new NettyClient(serviceAddress,servicePort);
        //构建异步数据处理
        RpcFuture<RpcResponse> future=new RpcFuture<>(new DefaultPromise<>(new DefaultEventLoop()));
        RequestHolder.REQUEST_MAP.put(requestId,future);
        nettyClient.sendRequest(protocol);
        return future.getPromise().get().getData();
    }
}

定义客户端连接

在netty-rpc-protocol这个模块的protocol包路径下,创建NettyClient

@Slf4j
public class NettyClient {
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup=new NioEventLoopGroup();
    private String serviceAddress;
    private int servicePort;
    public NettyClient(String serviceAddress,int servicePort){
        log.info("begin init NettyClient");
        bootstrap=new Bootstrap();
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                .handler(new RpcClientInitializer());
        this.serviceAddress=serviceAddress;
        this.servicePort=servicePort;
    }

    public void sendRequest(RpcProtocol<RpcRequest> protocol) throws InterruptedException {
        ChannelFuture future=bootstrap.connect(this.serviceAddress,this.servicePort).sync();
        future.addListener(listener->{
            if(future.isSuccess()){
                log.info("connect rpc server {} success.",this.serviceAddress);
            }else{
                log.error("connect rpc server {} failed .",this.serviceAddress);
                future.cause().printStackTrace();
                eventLoopGroup.shutdownGracefully();
            }
        });
        log.info("begin transfer data");
        future.channel().writeAndFlush(protocol);
    }
}

RpcClientInitializer

@Slf4j
public class RpcClientInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        log.info("begin initChannel");
        ch.pipeline()
                .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE,12,4,0,0))
                .addLast(new LoggingHandler())
                .addLast(new RpcEncoder())
                .addLast(new RpcDecoder())
                .addLast(new RpcClientHandler());
    }
}

RpcClientHandler

需要注意,Netty的通信过程是基于入站出站分离的,所以在获取结果时,我们需要借助一个Future对象来完成。

@Slf4j
public class RpcClientHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcResponse>> {

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcResponse> msg) throws Exception {
        log.info("receive rpc server result");
        long requestId=msg.getHeader().getRequestId();
        RpcFuture<RpcResponse> future=RequestHolder.REQUEST_MAP.remove(requestId);
        future.getPromise().setSuccess(msg.getContent()); //返回结果
    }
}

Future的实现

在netty-rpc-protocol模块中添加rpcFuture实现

RpcFuture

@Data
public class RpcFuture<T> {
    //Promise是可写的 Future, Future自身并没有写操作相关的接口,
    // Netty通过 Promise对 Future进行扩展,用于设置IO操作的结果
    private Promise<T> promise;

    public RpcFuture(Promise<T> promise) {
        this.promise = promise;
    }
}

RequestHolder

保存requestid和future的对应结果

public class RequestHolder {

    public static final AtomicLong REQUEST_ID=new AtomicLong();

    public static final Map<Long,RpcFuture> REQUEST_MAP=new ConcurrentHashMap<>();
}

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Mic带你学架构
如果本篇文章对您有帮助,还请帮忙点个关注和赞,您的坚持是我不断创作的动力。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消