博主头像
Kurfuerst

ワクワク

kitex源码阅读(一)

  • 我们可以从官方的示例kitex-example中的hello入手

服务初始化

main函数不出意外应该是这样的:

package main

import (
    "log"

    api "github.com/cloudwego/kitex-examples/hello/kitex_gen/api/hello"
)

func main() {
    svr := api.NewServer(new(HelloImpl))

    err := svr.Run()
    if err != nil {
        log.Println(err.Error())
    }
}

先创建Server对象,再调用svr.Run()运行程序。

让我们看一下api.NewServer创建Server对象的过程

// NewServer creates a server.Server with the given handler and options.
func NewServer(handler api.Hello, opts ...server.Option) server.Server {
    var options []server.Option

    options = append(options, opts...)
    options = append(options, server.WithCompatibleMiddlewareForUnary())

    svr := server.NewServer(options...)
    if err := svr.RegisterService(serviceInfo(), handler); err != nil {
        panic(err)
    }
    return svr
}

func RegisterService(svr server.Server, handler api.Hello, opts ...server.RegisterOption) error {
    return svr.RegisterService(serviceInfo(), handler, opts...)
}

这是kitex框架生成的代码,它先创建了一个server对象,然后调用RegisterService()在对象上注册相应的服务。

这个服务可以看作一个回调对象:当有请求时,kitex框架负责接受数据并解码,封装成请求体,再调用相应的服务与方法进行业务逻辑处理。

RegisterService()会调用s.svcs.addService()方法,将服务添加到server.svcs属性中。svcs是一个services对象,通过一个knownSvcMap map[string]*service存储相应的服务:

func (s *services) addService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, registerOpts *RegisterOptions) error {
    // unknown service
    // 处理未知服务,确保只能有一个未知服务
    if registerOpts.IsUnknownService {
        if s.unknownSvc != nil {
            return errors.New("multiple unknown services cannot be registered")
        }
        s.unknownSvc = &unknownService{svcs: map[string]*service{}, handler: handler}
        return nil
    }

    svc := newService(svcInfo, handler)
    // 处理回退服务,确保只能有一个回退服务
    if registerOpts.IsFallbackService {
        if s.fallbackSvc != nil {
            return fmt.Errorf("multiple fallback services cannot be registered. [%s] is already registered as a fallback service", s.fallbackSvc.svcInfo.ServiceName)
        }
        s.fallbackSvc = svc
    }
    // 处理普通服务
    // 防止重复注册
    if _, ok := s.knownSvcMap[svcInfo.ServiceName]; ok {
        return fmt.Errorf("service [%s] has already been registered", svcInfo.ServiceName)
    }
    s.knownSvcMap[svcInfo.ServiceName] = svc
    // 记录非回退服务
    if !registerOpts.IsFallbackService {
        s.nonFallbackSvcs = append(s.nonFallbackSvcs, svc)
    }
    return nil
}

启动服务

初始化完成后,就可以通过svr.Run()启动服务

// Run runs the server.
func (s *server) Run() (err error) {
    s.Lock()
    s.isRun = true
    s.Unlock()
    // 初始化中间件调用链 (buildInvokeChain)、调试服务、本地会话备份等
    s.init()
    if err = s.check(); err != nil {
        return err
    }
    // 获取远程通信相关的配置和监听地址
    svrCfg := s.opt.RemoteOpt
    addr := svrCfg.Address // should not be nil
    
    ...

    s.registerDebugInfo()
    // 添加一些内置的 BoundHandler,例如用于元信息处理、限流、连接数限制的处理器
    s.richRemoteOption()
    // 创建传输处理器
    transHdlr, err := s.newSvrTransHandler()
    if err != nil {
        return err
    }
    // 实例化 remotesvr.Server。负责网络监听、接收连接和读写数据,利用上文的TransHandler进行数据处理
    svr, err := remotesvr.NewServer(s.opt.RemoteOpt, transHdlr)
    if err != nil {
        return err
    }
    
    ...

    // 开始异步监听(非阻塞)
    errCh := svr.Start()
    
    ...

    // 进入阻塞状态,并将服务注册到服务发现中心
    if err = s.waitExit(errCh); err != nil {
        klog.Errorf("KITEX: received error and exit: error=%s", err.Error())
    }
    // 退出
    if e := s.Stop(); e != nil && err == nil {
        err = e
        klog.Errorf("KITEX: stop server error: error=%s", e.Error())
    }
    return
}

可以看到,启动的流程大致是以下几步:

  1. 初始化:执行 init(),构建中间件调用链。
  2. 配置检查:检查已注册的服务是否符合要求。
  3. 准备传输层:确定监听地址,并创建包含限流、元信息处理等逻辑的传输处理器 (TransHandler)。
  4. 创建底层 Server:实例化一个 remotesvr.Server,它封装了网络监听和数据收发的逻辑。
  5. 开始监听:调用 remotesvr.Server.Start(),服务器开始在网络端口上接收连接。
  6. 服务注册与等待:构建服务注册信息,然后阻塞等待退出信号。在阻塞期间,它会向服务发现中心注册自己,宣告服务可用。
  7. 退出:收到退出信号后,执行 Stop(),从服务中心反注册并关闭服务器。
中间件调用链
func (s *server) buildInvokeChain(ctx context.Context) {
    mws := s.buildMiddlewares(ctx)
    s.eps = endpoint.Chain(mws...)(s.unaryOrStreamEndpoint(ctx))
}

初始化过程中调用了buildInvokeChain(),构建中间件调用链。

func (s *server) buildMiddlewares(ctx context.Context) []endpoint.Middleware {
    ...

    // 3. universal middleware
    var mws []endpoint.Middleware
    // 分发中间件,判断是一元调用还是流式调用,并分发给相应的中间件链进行处理
    mws = append(mws, s.wrapStreamMiddleware())

    if s.opt.EnableContextTimeout {
        mws = append(mws, serverTimeoutMW)
    }
    // register server middlewares
    // 添加用户自定义的中间件
    for i := range s.opt.MWBs {
        if mw := s.opt.MWBs[i](ctx); mw != nil {
            mws = append(mws, mw)
        }
    }
    // register core middleware,
    // core middleware MUST be the last middleware
    mws = append(mws, s.buildCoreMiddleware())
    return mws
}

buildMiddlewares会先向中间件链中加入一个分发中间件wrapStreamMiddleware,负责识别请求类型(一元或流式)并分发到正确的调用链。随后添加用户自定义的中间件。最末端则是一个coreMiddleware,负责执行ACL与错误处理等框架级别的逻辑。

随后的Chain()会按下标逆序逐个包裹mws中的中间件,最内层则是s.unaryOrStreamEndpoint(ctx)创建的分发器,将请求交给对应的业务逻辑处理链。

这样我们就得到了一个完整的中间件调用链。

检查同名方法冲突

server在启动时会调用service的check方法,确保至少注册了一个服务且多个服务之间没有同名方法冲突

// 检查服务是否已注册,以及是否存在多个服务但未指定回退服务(fallback service)等配置问题
func (s *services) check(refuseTrafficWithoutServiceName bool) error {
    // 检查是否至少注册了一个服务(无论是普通服务 knownSvcMap 还是泛化调用服务 unknownSvc)
    if len(s.knownSvcMap) == 0 && s.unknownSvc == nil {
        return errors.New("run: no service. Use RegisterService to set one")
    }
    for _, svc := range s.knownSvcMap {
        // binary thrift generic v1无法与其他服务一起注册
        if svc.svcInfo.ServiceName == serviceinfo.GenericService {
            s.binaryThriftGenericV1SvcInfo = svc.svcInfo
            if len(s.knownSvcMap) > 1 { return error }
            if s.unknownSvc != nil { return error }
            return nil
        }
        // 同上,combine service合并服务也无法与其他服务一起注册
        if isCombineService, _ := svc.svcInfo.Extra[serviceinfo.CombineServiceKey].(bool); isCombineService {
            ...
        }
    }
    // 配置泛化服务(用于处理未明确注册的服务或方法的泛化调用)降级方案,会遍历所有已知的普通服务,注入泛化方法作为降级方案
    if s.unknownSvc != nil {
        for _, svc := range s.knownSvcMap {
            // 若该服务没有对应的方法,则会被转交给unknownSvb的handler处理
        }
    }
    
    ...
    
    // 检测同名方法冲突
    fallbackCheckingMap := make(map[string]int)
    // 遍历服务与方法
    for _, svc := range s.knownSvcMap {
        for method := range svc.svcInfo.Methods {
            // 若为回退服务,则该方法标记-1
            if svc == s.fallbackSvc {
                fallbackCheckingMap[method] = -1
            } else if num := fallbackCheckingMap[method]; num >= 0 {
                fallbackCheckingMap[method] = num + 1
            }
        }
    }
    // 若某方法在多个服务中存在则报错
    for methodName, serviceNum := range fallbackCheckingMap {
        if serviceNum > 1 {
            return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", methodName)
        }
    }
    return nil
}

首先,check会检查是否至少注册了一个服务(无论是普通服务 knownSvcMap 还是泛化调用服务 unknownSvc)。

处理特殊服务(互斥检查):检查是否存在旧版 Thrift 泛化服务 (binary thrift generic v1)与合并服务 (combine service),这两类服务不能与其他服务一起注册

随后配置泛化调用降级:如果存在 unknownSvc(用于处理未明确注册的服务或方法的泛化调用),则遍历所有已知的普通服务。它会为每个普通服务动态地注入一个作为降级方案。如果请求的某个服务没有对应的方法,请求不会立即失败,而是会被转交给 unknownSvc 的处理器来尝试进行泛化处理。

检查方法名冲突:这是最关键的一步。当服务器注册了多个服务时,不同的服务可能会有同名的方法(例如,两个服务都有一个叫 GetUser 的方法)。为了解决这种冲突,Kitex 允许指定一个 fallbackSvc(回退服务)。check函数会遍历所有已注册服务的所有方法并用一个 map 来统计每个方法在非回退服务中出现了多少次。

最后检查结果。如果发现任何一个方法在超过一个非回退服务中出现(即 serviceNum > 1),并且没有指定回退服务来解决这个冲突,函数就会返回一个错误指出哪个方法名冲突了。

监听

s.richRemoteOption()这个方法会添加各种绑定处理器,从而实现性能分析、流控等功能。随后实例化一个remotesvr.Server并开始异步监听。

// Start starts the server and return chan, the chan receive means server shutdown or err happen
func (s *server) Start() chan error {
    errCh := make(chan error, 1)
    ln, err := s.buildListener()
    if err != nil {
        errCh <- err
        return errCh
    }

    s.Lock()
    s.listener = ln
    s.Unlock()
    
    // 调用BootstrapServer处理连接
    gofunc.GoFunc(context.Background(), func() { errCh <- s.transSvr.BootstrapServer(ln) })
    return errCh
}

可以看到Start()方法是一个非阻塞方法,会立即返回一个error类型的channel。这个channel用于接受真正的服务循环s.transSvr.BootstrapServer(ln)的返回值,当从这个channel中接收到值时就意味着服务已经停止运行。

  • gofunc.GoFunc()是 Kitex 提供的一个 goroutine 包装器。它类似于 Go 语言原生的 go func() {},但增加了一个关键功能:panic 恢复。如果内部的函数执行时发生 panic,gofunc 会捕获它并将其作为错误处理,从而防止整个服务进程因未捕获的 panic 而崩溃,增强了服务的稳定性。

BootstrapServer方法有两个实现,先看一下标准库gonet的实现:

func (ts *transServer) BootstrapServer(ln net.Listener) error {
    ...

    for {
        conn, err := ln.Accept()
        if err != nil {
            ...
            
            klog.Errorf("KITEX: BootstrapServer accept failed, err=%s", err.Error())
            return err
        }
        go ts.serveConn(context.Background(), conn)
    }
}

conn, err := ln.Accept()是循环的核心。Accept() 是一个阻塞调用,它会一直等待,直到有新的客户端连接请求到达。当一个连接被接受后,它会返回代表该连接的 net.Conn 对象和一个 error。如果errornil,则连接正常,会启动一个新的goroutine处理这个连接。

func (ts *transServer) serveConn(ctx context.Context, conn net.Conn) (err error) {
    // 兜底 panic恢复,防止单个连接的崩溃导致整个server进程推出
    defer transRecover(ctx, conn, "serveConn")

    ...

    // 初始化
    ctx, err = ts.transHdlr.OnActive(ctx, bc)
    if err != nil {
        klog.CtxErrorf(ctx, "KITEX: OnActive error=%s", err)
        return err
    }
    for {
        // 设置连接空闲超时
        ts.refreshIdleDeadline(bc)
        _, err = bc.r.Peek(1)
        if err != nil {
            return err
        }
        // 设置请求读取超时,确保服务器不会因为缓慢或恶意的客户端一直不发送完整数据而被永久阻塞
        ts.refreshReadDeadline(rpcinfo.GetRPCInfo(ctx), bc)
        // FIXME: for gRPC transHandler, OnRead should execute only once.
        // 开始读取
        err = ts.transHdlr.OnRead(ctx, bc)
        if err != nil {
            return err
        }
    }
}

serveConn()的逻辑大致是先初始化连接,成功后开始循环进行数据读取。transServer设置了连接空闲超时与请求读取超时,防止连接空闲时间过长以及读取超时阻塞导致的资源浪费与泄露。读取循环会一直运行,直到出现错误或数据读取完毕(连接空闲超时)

该方法和go标准库中的ListenAndServe方法很像,都是让一个goroutine去等待连接建立,再交给其他goroutine接收并处理数据

服务注册

开启监听后,Run()方法会调用waitExit()方法阻塞主线程,直到退出

func (s *server) waitExit(errCh chan error) error {
    exitSignal := s.opt.ExitSignal()

    // service may not be available as soon as startup.
    delayRegister := time.After(1 * time.Second)
    for {
        select {
        case err := <-exitSignal:
            return err
        case err := <-errCh:
            return err
        case <-delayRegister:
            s.Lock()
            if err := s.opt.Registry.Register(s.opt.RegistryInfo); err != nil {
                s.Unlock()
                return err
            }
            s.Unlock()
        }
    }
}

它主要监听以下三个信号:

  • 外部退出信号 (exitSignal): 用于接收操作系统的信号(如 SIGINT, SIGTERM),当这个 channel 收到信号时,方法会返回收到的 error,从而触发服务停止流程。
  • 内部服务器错误 (errCh): 这个 channel 是从底层的网络服务器(remotesvr.Server)传递过来的。如果在监听、接受连接或处理数据时发生不可恢复的错误,该错误会被发送到这个 channel。waitExit 接收到错误后会立即返回,同样触发服务停止。
  • 延迟注册定时器 (delayRegister): 这是一个通过 time.After(1 * time.Second) 创建的定时器。它会延迟1秒钟再执行服务注册。这样做的目的是为了确保服务器已经完全准备好处理请求,然后再将自己注册到服务发现中心(如 Nacos、etcd),避免客户端过早地发现一个尚未完全就绪的服务实例。
退出

Stop()方法用于在服务结束后进行收尾工作

// Stop stops the server gracefully.
func (s *server) Stop() (err error) {
    s.stopped.Do(func() {
        s.Lock()
        defer s.Unlock()

        // 执行关闭相关钩子函数
        muShutdownHooks.Lock()
        for i := range onShutdown {
            onShutdown[i]()
        }
        muShutdownHooks.Unlock()

        // 服务注册中心取消注册
        if s.opt.RegistryInfo != nil {
            err = s.opt.Registry.Deregister(s.opt.RegistryInfo)
            s.opt.RegistryInfo = nil
        }
        
        // 停止监听
        if s.svr != nil {
            if e := s.svr.Stop(); e != nil {
                err = e
            }
            s.svr = nil
        }
    })
    return
}

server结构体的定义中,stopped的类型是sync.Once

type server struct {
    // ...
    stopped sync.Once
    // ...
}

它使得 Stop()方法是幂等且线程安全的。无论你调用 Stop() 一次还是多次,效果都是相同的——服务器只会被关闭一次。后续的调用会立即返回,不会产生任何副作用,也不会报错。并且sync.Once 可以保证即使有多个协程并发调用 Stop(),包裹在 Do 方法里的实际关闭代码(服务反注册、停止网络监听等)也只会被其中一个协程执行一次,从而避免了竞态条件和重复关闭导致的混乱。

kitex源码阅读(一)
http://blog.kurfuerst.online/index.php/archives/35/
本文作者 Großer Kurfürst
发布时间 2025-10-10
许可协议 CC BY-NC-SA 4.0
发表新评论