前面談過gRPC的SSL/TLS安全機制,發現設置過程比較複雜:比如證書簽名需要服務端、客戶端兩頭都設置等。想想實際上用JWT會更加便捷,而且更安全、功能更強大,因為除了JWT的簽名之外,還可以把用戶信息放在JWT裏在服務端和客戶端之間傳遞。(注意:簽名只能保證JWT不被篡改,payload只是Base64編碼、並未加密,任何人都能讀取,所以不要在裏面放密碼等私密信息;在沒有TLS的明文通道上,JWT本身也可能被截取冒用。)當然,最基本的是通過對JWT的驗證機制可以控制客戶端對某些功能的使用權限。
通過JWT實現gRPC的函數調用權限管理,原理其實很簡單:客戶端首先通過身份驗證從服務端獲取JWT,然後在調用服務函數時把這個JWT一併傳給服務端進行權限驗證。客戶端提交身份驗證請求、獲取JWT可以用一個獨立的服務函數實現,如下面.proto文件裏的GetAuthToken:
// Credentials the client submits to obtain a token.
message PBPOSCredential {
  string userid   = 1;
  string password = 2;
}

// Reply of GetAuthToken; on bad credentials the server puts the literal "Invalid Token!" here.
message PBPOSToken {
  string jwt = 1;
}

service SendCommand {
  rpc SingleResponse (PBPOSCommand)    returns (PBPOSResponse);
  rpc GetTxnItems    (PBPOSCommand)    returns (stream PBTxnItem);
  // Exchanges credentials for a JWT.
  rpc GetAuthToken   (PBPOSCredential) returns (PBPOSToken);
}
比較棘手的是如何把JWT從客戶端傳送至服務端,因為gRPC基本上接管了Request和Response。其中一個方法是通過Interceptor來截取Request的header,即metadata:客戶端將JWT寫入metadata,服務端從metadata讀取JWT。
我們先看看客戶端的Interceptor設置和使用:
/** Client-side interceptor that attaches `jwt` to every outgoing call
  * as the ASCII metadata header "jwt".
  *
  * @param jwt token previously obtained from GetAuthToken
  */
class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
  // Header under which the token travels; the server's Keys.AUTH_META_KEY uses the same name.
  private val jwtHeader: Metadata.Key[String] =
    Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)

  def interceptCall[ReqT, RespT](methodDescriptor: MethodDescriptor[ReqT, RespT], callOptions: CallOptions, channel: io.grpc.Channel): ClientCall[ReqT, RespT] = {
    val underlying = channel.newCall(methodDescriptor, callOptions)
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](underlying) {
      override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        headers.put(jwtHeader, jwt)
        super.start(responseListener, headers)
      }
    }
  }
}
...
// Plain-text channel to the demo server (no TLS, so the JWT header travels unencrypted).
val unsafeChannel = NettyChannelBuilder
.forAddress("192.168.0.189",50051)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
// Wrap the channel so every call made through it carries the "jwt" metadata header.
val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
val securedClient = SendCommandGrpc.blockingStub(securedChannel)
// Blocking unary call that needs the token.
val resp = securedClient.singleResponse(PBPOSCommand())
身份驗證請求(即獲取JWT)是不需要Interceptor的,所以要用沒有Interceptor的unsafeChannel:
// Build a plain-text connection channel WITHOUT the auth interceptor: obtaining the token needs no JWT.
val unsafeChannel = NettyChannelBuilder
.forAddress("192.168.0.189",50051)
.negotiationType(NegotiationType.PLAINTEXT)
.build()
// Blocking stub on the un-intercepted channel.
val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
// Exchange credentials for a token; for unknown credentials getAuthToken replies with the literal "Invalid Token!".
val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
println(s"got jwt: $jwt")
JWT的構建和使用已經在前面的幾篇博文裏討論過了:
package com.datatech.auth
import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._
object AuthBase {
  /** Free-form user attributes embedded in (and extracted from) the token's "userinfo" claim. */
  type UserInfo = Map[String, Any]

  /** Immutable JWT helper: issues tokens carrying a "userinfo" claim and validates/decodes them.
    *
    * @param algorithm   signing algorithm; HMAC and asymmetric families are handled,
    *                    anything else is treated as "cannot validate"
    * @param secret      key passed to both `Jwt.encode` and `Jwt.isValid`/`decodeRawAll`
    *                    (NOTE(review): one string for signing and verifying only really fits HMAC)
    * @param getUserInfo credential check `(userid, password) => Option[UserInfo]`; stored for
    *                    callers, not used by this class itself
    */
  case class AuthBase(
    algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
    secret: String = "OpenSesame",
    getUserInfo: (String,String) => Option[UserInfo] = null) {
    ctx =>
    def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
    def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
    def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)

    /** Returns `Some(token)` when its signature verifies with `secret`/`algorithm`, else `None`.
      * An unsupported algorithm yields `None` rather than a ClassCastException.
      */
    def authenticateToken(token: String): Option[String] =
      Some(token).filter(isValidToken)

    /** Decodes a verified token and returns its "userinfo" object.
      *
      * @return `None` when the token does not verify, the algorithm is unsupported, or the
      *         claim has no "userinfo" JSON object (previously a ClassCastException)
      */
    def getUserInfo(token: String): Option[UserInfo] =
      decodeClaim(token).flatMap(extractUserInfo)

    /** Issues a token whose claim is `{"userinfo": userinfo}`, signed with `secret`/`algorithm`.
      * NOTE(review): no `exp` claim is set, so issued tokens never expire.
      */
    def issueJwt(userinfo: UserInfo): String = {
      val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
      Jwt.encode(claims, secret, algorithm)
    }

    // Signature check, dispatched on the algorithm family via the bound pattern variable.
    private def isValidToken(token: String): Boolean = algorithm match {
      case algo: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(algo))
      case algo: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(algo))
      case _                            => false
    }

    // Claim JSON (second part returned by decodeRawAll) of a token that verifies.
    private def decodeClaim(token: String): Option[String] = {
      val decoded = algorithm match {
        case algo: JwtAsymmetricAlgorithm => Jwt.decodeRawAll(token, secret, Seq(algo))
        case algo: JwtHmacAlgorithm       => Jwt.decodeRawAll(token, secret, Seq(algo))
        case _ => Failure(new IllegalArgumentException(s"unsupported algorithm: $algorithm"))
      }
      decoded.toOption.map(_._2)
    }

    // The "userinfo" member of the claim, only when it is present and is a JSON object.
    private def extractUserInfo(claimJson: String): Option[UserInfo] =
      Try(parse(claimJson) \ "userinfo").toOption.collect {
        case obj: JObject => obj.values
      }
  }
}
服務端Interceptor的構建和設置如下:
/** Listener that forwards every event to a listener produced asynchronously by `delegate`.
  *
  * Events are forwarded strictly one after another, in the order gRPC delivered them: each
  * event is chained onto the completion of the previous one. (Calling `delegate.foreach` per
  * event queued every event as an independent task on `ec`, so on a multi-threaded executor
  * e.g. onHalfClose could run before onMessage and fail a unary call.)
  *
  * If `delegate` fails, events are dropped, as with the previous `foreach`-based forwarding.
  *
  * @param ec executor on which forwarded events run
  */
abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
  protected val delegate: Future[Listener[Q]]

  // Completion of the most recently forwarded event. It starts already completed, so `delegate`
  // (initialised by the subclass only after this constructor has run) is read lazily in `eventually`.
  private[this] var lastEvent: Future[Any] = Future.unit

  // Schedule `event` after all previously forwarded events and after `delegate` is available.
  private def eventually(event: Listener[Q] => Unit): Unit = synchronized {
    lastEvent = lastEvent
      .transformWith(_ => delegate)
      .andThen { case scala.util.Success(listener) => event(listener) }
  }

  override def onComplete(): Unit = eventually { _.onComplete() }
  override def onCancel(): Unit = eventually { _.onCancel() }
  override def onMessage(message: Q): Unit = eventually { _ onMessage message }
  override def onHalfClose(): Unit = eventually { _.onHalfClose() }
  override def onReady(): Unit = eventually { _.onReady() }
}
/** Names under which the JWT travels on the server side. */
object Keys {
  // ASCII request-metadata header written by the client.
  val AUTH_META_KEY: Metadata.Key[String] = Metadata.Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)
  // gRPC Context slot the interceptor fills and service methods read via AUTH_CTX_KEY.get.
  val AUTH_CTX_KEY: Context.Key[String] = Context.key("jwt")
}
/** Server interceptor that copies the "jwt" request header into the gRPC Context so that
  * service methods can read it with `Keys.AUTH_CTX_KEY.get`.
  *
  * It does not reject calls itself: when the header is absent, `headers.get` yields null and
  * that null is stored; each service method decides what an invalid token means.
  *
  * @param ec kept for source compatibility with existing `new AuthorizationInterceptor` call
  *           sites; interception is now synchronous and does not need it
  */
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
  override def interceptCall[Q, R](
    call: ServerCall[Q, R],
    headers: Metadata,
    next: ServerCallHandler[Q, R]
  ): Listener[Q] = {
    val jwt = headers.get(Keys.AUTH_META_KEY)
    // NOTE(review): prints the raw bearer token; acceptable for this demo only.
    println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
    // Build the handler's listener synchronously: Contexts.interceptCall makes `nextCtx` current
    // while the listener is created and during every later listener callback, whichever thread
    // gRPC delivers it on. Creating it inside a Future and forwarding events through
    // FutureListener let events such as onMessage/onHalfClose run out of order.
    val nextCtx = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
    Contexts.interceptCall(nextCtx, call, headers, next)
  }
}
/** Mixin that starts a Netty gRPC server for one service, wrapped by AuthorizationInterceptor. */
trait gRPCServer {
  /** Starts the server (returns immediately) and registers a JVM shutdown hook that stops it.
    *
    * @param service  service definition to expose
    * @param port     TCP port to listen on; defaults to the previously hard-coded 50051
    * @param actorSys its dispatcher is the implicit ExecutionContext for the interceptor
    */
  def runServer(service: ServerServiceDefinition, port: Int = 50051)(implicit actorSys: ActorSystem): Unit = {
    import actorSys.dispatcher
    // Seconds in-flight calls get to finish at JVM shutdown before they are cancelled.
    val graceSeconds = 10L
    val server = NettyServerBuilder
      .forPort(port)
      .addService(ServerInterceptors.intercept(service, new AuthorizationInterceptor))
      .build
      .start
    // make sure our server is stopped when jvm is shut down; the wait is bounded so that a
    // hung call cannot block JVM exit forever (awaitTermination() without a timeout could)
    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        server.shutdown()
        if (!server.awaitTermination(graceSeconds, java.util.concurrent.TimeUnit.SECONDS))
          server.shutdownNow()
      }
    })
  }
}
注意:客戶端上傳的request-header只能在服務端的Interceptor裏(即構建server時加入的ServerInterceptor)接觸到,在具體服務函數裏是無法讀取request-header的,但gRPC有一個結構Context在兩個地方都能訪問。所以,我們可以在Interceptor裏把JWT從header搬到Context裏。不過,千萬注意Context是與線程綁定的,寫入和讀取必須在同一個線程裏(Contexts.interceptCall會在listener的每次回調期間把新的Context設為當前Context)。在服務端的Interceptor裏我們把JWT從metadata裏讀出然後寫入Context。在需要權限管理的服務函數裏再從Context裏讀取JWT進行驗證:
/** Demo call that requires a JWT: reads the token the interceptor stored in the gRPC Context
  * and echoes the caller's shop id from the token's userinfo.
  *
  * It never throws: a missing/invalid token yields "shopid:invalid token!", and a valid
  * token without a "shopid" entry yields "shopid:missing shopid!" (previously `m("shopid")`
  * threw NoSuchElementException).
  */
override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
  // May be null when the client sent no "jwt" header.
  val jwt = AUTH_CTX_KEY.get
  println(s"***********$jwt**************")
  val shopid = authenticator.getUserInfo(jwt) match {
    case Some(userInfo) => userInfo.getOrElse("shopid", "missing shopid!")
    case None => "invalid token!"
  }
  FastFuture.successful(PBPOSResponse(msg=s"shopid:$shopid"))
}
JWT的構建也是一個服務函數:
// Shared JWT helper: HS256 tokens signed with "OpenSesame"; credentials are checked by getValidUser.
val authenticator = new AuthBase(
  algorithm = JwtAlgorithm.HS256,
  secret = "OpenSesame",
  getUserInfo = getValidUser)
/** Exchanges credentials for a signed JWT.
  * Bad credentials do not fail the RPC; the reply instead carries the literal "Invalid Token!".
  */
override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
  val token = getValidUser(request.userid, request.password)
    .map(authenticator.issueJwt)
    .getOrElse("Invalid Token!")
  FastFuture.successful(PBPOSToken(token))
}
還需要一個模擬的身份驗證服務函數:
package com.datatech.auth
/** In-memory stand-in for a real user store (demo only: passwords are kept in plain text). */
object MockUserAuthService {
  type UserInfo = Map[String,Any]

  /** A registered user plus the attributes handed to the token issuer on successful login. */
  case class User(username: String, password: String, userInfo: UserInfo)

  val validUsers = Seq(
    User("johnny", "p4ssw0rd", Map("shopid" -> "1101", "userid" -> "101")),
    User("tiger", "secret", Map("shopid" -> "1101", "userid" -> "102"))
  )

  /** Returns the attributes of the first user whose name and password both match, else None. */
  def getValidUser(userid: String, pswd: String): Option[UserInfo] =
    validUsers.collectFirst {
      case User(name, password, info) if name == userid && password == pswd => info
    }
}
下面是本次示範的源代碼:
project/plugins.sbt
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.15")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.21")

// ScalaPB code generator for sbt-protoc; also provides scalapb.compiler.Version used in build.sbt.
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.9.0-M6"
build.sbt
name := "grpc-jwt"

version := "0.1"

scalaVersion := "2.12.8"

scalacOptions += "-Ypartial-unification"

val akkaversion = "2.5.23"

// `++=` rather than `:=`: `:=` would replace any libraryDependencies set earlier (e.g. by plugins).
libraryDependencies ++= Seq(
  "com.typesafe.akka" %% "akka-cluster-metrics" % akkaversion,
  "com.typesafe.akka" %% "akka-cluster-sharding" % akkaversion,
  "com.typesafe.akka" %% "akka-persistence" % akkaversion,
  "com.lightbend.akka" %% "akka-stream-alpakka-cassandra" % "1.0.1",
  "org.mongodb.scala" %% "mongo-scala-driver" % "2.6.0",
  "com.lightbend.akka" %% "akka-stream-alpakka-mongodb" % "1.0.1",
  "com.typesafe.akka" %% "akka-persistence-query" % akkaversion,
  "com.typesafe.akka" %% "akka-persistence-cassandra" % "0.97",
  "com.datastax.cassandra" % "cassandra-driver-core" % "3.6.0",
  "com.datastax.cassandra" % "cassandra-driver-extras" % "3.6.0",
  "ch.qos.logback" % "logback-classic" % "1.2.3",
  "io.monix" %% "monix" % "3.0.0-RC2",
  "org.typelevel" %% "cats-core" % "2.0.0-M1",
  "io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion,
  "io.netty" % "netty-tcnative-boringssl-static" % "2.0.22.Final",
  "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
  "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion,
  "com.pauldijou" %% "jwt-core" % "3.0.1",
  "de.heikoseeberger" %% "akka-http-json4s" % "1.22.0",
  // json4s modules share json4s-core, so keep them on one version (was 3.6.1 vs 3.6.7).
  "org.json4s" %% "json4s-native" % "3.6.7",
  "com.typesafe.akka" %% "akka-http-spray-json" % "10.1.8",
  "org.json4s" %% "json4s-jackson" % "3.6.7",
  "org.json4s" %% "json4s-ext" % "3.6.7"
)

// (optional) If you need scalapb/scalapb.proto or anything from
// google/protobuf/*.proto
//libraryDependencies += "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf"

// Generate Scala sources for the .proto files into the managed source directory.
PB.targets in Compile := Seq(
  scalapb.gen() -> (sourceManaged in Compile).value
)

enablePlugins(JavaAppPackaging)
main/protobuf/posmessages.proto
syntax = "proto3";

import "google/protobuf/wrappers.proto";
import "google/protobuf/any.proto";
import "scalapb/scalapb.proto";

option (scalapb.options) = {
  // use a custom Scala package name
  // package_name: "io.ontherocks.introgrpc.demo"

  // don't append file name to package
  flat_package: true

  // generate one Scala file for all messages (services still get their own file)
  single_file: true

  // add imports to generated file
  // useful when extending traits or using custom types
  // import: "io.ontherocks.hellogrpc.RockingMessage"

  // code to put at the top of generated file
  // works only with `single_file: true`
  //preamble: "sealed trait SomeSealedTrait"
};

package com.datatech.pos.messages;

// Voucher (sales slip) state.
message PBVchState {
  string opr  = 1;   // cashier
  int64  jseq = 2;   // begin journal sequence for read-side replay
  int32  num  = 3;   // current voucher number
  int32  seq  = 4;   // current sequence number
  bool   void = 5;   // void mode
  bool   refd = 6;   // refund mode
  bool   susp = 7;   // voucher suspended
  bool   canc = 8;   // voucher cancelled
  bool   due  = 9;   // current balance due (NOTE(review): declared bool, not an amount)
  string su   = 10;  // supervisor id
  string mbr  = 11;  // member number
  int32  mode = 12;  // current operation mode: 0=logOff, 1=LogOn, 2=Payment
}

// Transaction record.
message PBTxnItem {
  string txndate   = 1;   // transaction date
  string txntime   = 2;   // entry time
  string opr       = 3;   // operator
  int32  num       = 4;   // sales voucher number
  int32  seq       = 5;   // transaction sequence number
  int32  txntype   = 6;   // transaction type
  int32  salestype = 7;   // sales type
  int32  qty       = 8;   // quantity
  int32  price     = 9;   // unit price (in cents / fen)
  int32  amount    = 10;  // gross amount at list price (in cents / fen)
  int32  disc      = 11;  // discount rate (%)
  int32  dscamt    = 12;  // discount amount, negative; net amount = amount + dscamt
  string member    = 13;  // member card number
  string code      = 14;  // code (item, card number, ...)
  string acct      = 15;  // account
  string dpt       = 16;  // department
}

message PBPOSResponse {
  int32 sts = 1;
  string msg = 2;
  PBVchState voucher = 3;
  repeated PBTxnItem txnitems = 4;
}

message PBPOSCommand {
  string commandname = 1;
  string delimitedparams = 2;
}

message PBPOSCredential {
  string userid = 1;
  string password = 2;
}

message PBPOSToken {
  string jwt = 1;
}

service SendCommand {
  rpc SingleResponse(PBPOSCommand) returns (PBPOSResponse) {};
  rpc GetTxnItems(PBPOSCommand) returns (stream PBTxnItem) {};
  rpc GetAuthToken(PBPOSCredential) returns (PBPOSToken) {};
}
gRPCServer.scala
package com.datatech.grpc.server

import io.grpc.ServerServiceDefinition
import io.grpc.netty.NettyServerBuilder
import io.grpc.ServerInterceptors
import scala.concurrent._
import scala.util.Success
import java.util.concurrent.TimeUnit
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import io.grpc.Metadata
import io.grpc.Metadata.Key.of
import io.grpc.Context.key
import io.grpc.ServerCall.Listener
import akka.actor._

/** Listener that forwards every event to a listener produced asynchronously by `delegate`.
  *
  * Events are forwarded strictly one after another, in delivery order: each event is chained
  * onto the completion of the previous one. (Calling `delegate.foreach` per event queued each
  * event as an independent task on `ec`, so onHalfClose could overtake onMessage.)
  * If `delegate` fails, events are dropped.
  *
  * @param ec executor on which forwarded events run
  */
abstract class FutureListener[Q](implicit ec: ExecutionContext) extends Listener[Q] {
  protected val delegate: Future[Listener[Q]]

  // Completion of the most recently forwarded event; starts completed so that `delegate`
  // (set by the subclass after this constructor) is only read inside `eventually`.
  private[this] var lastEvent: Future[Any] = Future.unit

  // Schedule `event` after all earlier events and after `delegate` is available.
  private def eventually(event: Listener[Q] => Unit): Unit = synchronized {
    lastEvent = lastEvent
      .transformWith(_ => delegate)
      .andThen { case Success(listener) => event(listener) }
  }

  override def onComplete(): Unit = eventually { _.onComplete() }
  override def onCancel(): Unit = eventually { _.onCancel() }
  override def onMessage(message: Q): Unit = eventually { _ onMessage message }
  override def onHalfClose(): Unit = eventually { _.onHalfClose() }
  override def onReady(): Unit = eventually { _.onReady() }
}

/** Names under which the JWT travels: request-metadata header and gRPC Context slot. */
object Keys {
  val AUTH_META_KEY: Metadata.Key[String] = of("jwt", Metadata.ASCII_STRING_MARSHALLER)
  val AUTH_CTX_KEY: Context.Key[String] = key("jwt")
}

/** Copies the "jwt" request header into the gRPC Context for service methods to read.
  *
  * It does not reject calls; a missing header stores null.
  *
  * @param ec kept for source compatibility; interception is synchronous and does not use it
  */
class AuthorizationInterceptor(implicit ec: ExecutionContext) extends ServerInterceptor {
  override def interceptCall[Q, R](
    call: ServerCall[Q, R],
    headers: Metadata,
    next: ServerCallHandler[Q, R]
  ): Listener[Q] = {
    val jwt = headers.get(Keys.AUTH_META_KEY)
    // NOTE(review): prints the raw bearer token; acceptable for this demo only.
    println(s"!!!!!!!!!!! $jwt !!!!!!!!!!")
    // Contexts.interceptCall makes nextCtx current during listener creation and during every
    // later listener callback, so no Future hop (and no event reordering) is needed.
    val nextCtx = Context.current.withValue(Keys.AUTH_CTX_KEY, jwt)
    Contexts.interceptCall(nextCtx, call, headers, next)
  }
}

/** Mixin that starts a Netty gRPC server for one service, wrapped by AuthorizationInterceptor. */
trait gRPCServer {
  /** Starts the server (returns immediately) and registers a JVM shutdown hook that stops it.
    *
    * @param service  service definition to expose
    * @param port     TCP port to listen on; defaults to the previously hard-coded 50051
    * @param actorSys its dispatcher is the implicit ExecutionContext for the interceptor
    */
  def runServer(service: ServerServiceDefinition, port: Int = 50051)(implicit actorSys: ActorSystem): Unit = {
    import actorSys.dispatcher
    // Seconds in-flight calls get to finish at JVM shutdown before they are cancelled.
    val graceSeconds = 10L
    val server = NettyServerBuilder
      .forPort(port)
      .addService(ServerInterceptors.intercept(service, new AuthorizationInterceptor))
      .build
      .start
    // make sure our server is stopped when jvm is shut down, with a bounded wait
    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        server.shutdown()
        if (!server.awaitTermination(graceSeconds, TimeUnit.SECONDS))
          server.shutdownNow()
      }
    })
  }
}
POSServices.scala
package com.datatech.pos.service

import com.datatech.grpc.server.Keys._
import akka.http.scaladsl.util.FastFuture
import com.datatech.pos.messages._
import com.datatech.grpc.server._
import com.datatech.auth.MockUserAuthService._
import scala.concurrent.Future
import com.datatech.auth.AuthBase._
import pdi.jwt._
import akka.actor._
import io.grpc.Status
import io.grpc.stub.StreamObserver

/** Entry point of the demo server; mixes in gRPCServer to run the SendCommand service. */
object POSServices extends gRPCServer {
  type UserInfo = Map[String, Any]

  /** SendCommand implementation whose calls are authorized with JWTs. */
  class POSServices extends SendCommandGrpc.SendCommand {
    val authenticator = new AuthBase()
      .withAlgorithm(JwtAlgorithm.HS256)
      .withSecretKey("OpenSesame")
      .withUserFunc(getValidUser)

    /** Not implemented yet: replies UNIMPLEMENTED instead of throwing NotImplementedError (`???`). */
    override def getTxnItems(request: PBPOSCommand, responseObserver: StreamObserver[PBTxnItem]): Unit =
      responseObserver.onError(
        Status.UNIMPLEMENTED.withDescription("GetTxnItems is not implemented").asRuntimeException())

    /** Echoes the caller's shop id from the JWT in the gRPC Context; never throws. */
    override def singleResponse(request: PBPOSCommand): Future[PBPOSResponse] = {
      // May be null when the client sent no "jwt" header.
      val jwt = AUTH_CTX_KEY.get
      println(s"***********$jwt**************")
      val shopid = authenticator.getUserInfo(jwt) match {
        case Some(userInfo) => userInfo.getOrElse("shopid", "missing shopid!")
        case None => "invalid token!"
      }
      FastFuture.successful(PBPOSResponse(msg=s"shopid:$shopid"))
    }

    /** Exchanges credentials for a JWT; bad credentials yield the literal "Invalid Token!". */
    override def getAuthToken(request: PBPOSCredential): Future[PBPOSToken] = {
      getValidUser(request.userid, request.password) match {
        case Some(userinfo) => FastFuture.successful(PBPOSToken(authenticator.issueJwt(userinfo)))
        case None => FastFuture.successful(PBPOSToken("Invalid Token!"))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    implicit val system = ActorSystem("grpc-system")
    val svc = SendCommandGrpc.bindService(new POSServices, system.dispatcher)
    runServer(svc)
  }
}
AuthBase.scala
package com.datatech.auth

import pdi.jwt._
import org.json4s.native.Json
import org.json4s._
import org.json4s.jackson.JsonMethods._
import pdi.jwt.algorithms._
import scala.util._

object AuthBase {
  /** Free-form user attributes embedded in (and extracted from) the token's "userinfo" claim. */
  type UserInfo = Map[String, Any]

  /** Immutable JWT helper: issues tokens carrying a "userinfo" claim and validates/decodes them.
    *
    * @param algorithm   signing algorithm; HMAC and asymmetric families are handled,
    *                    anything else is treated as "cannot validate"
    * @param secret      key passed to both `Jwt.encode` and `Jwt.isValid`/`decodeRawAll`
    *                    (NOTE(review): one string for signing and verifying only really fits HMAC)
    * @param getUserInfo credential check `(userid, password) => Option[UserInfo]`; stored for
    *                    callers, not used by this class itself
    */
  case class AuthBase(
    algorithm: JwtAlgorithm = JwtAlgorithm.HMD5,
    secret: String = "OpenSesame",
    getUserInfo: (String,String) => Option[UserInfo] = null) {
    ctx =>
    def withAlgorithm(algo: JwtAlgorithm): AuthBase = ctx.copy(algorithm = algo)
    def withSecretKey(key: String): AuthBase = ctx.copy(secret = key)
    def withUserFunc(f: (String, String) => Option[UserInfo]): AuthBase = ctx.copy(getUserInfo = f)

    /** Returns `Some(token)` when its signature verifies, else `None` (also for unsupported algorithms). */
    def authenticateToken(token: String): Option[String] =
      Some(token).filter(isValidToken)

    /** Returns the "userinfo" object of a verified token; `None` for an invalid token, an
      * unsupported algorithm, or a missing / non-object "userinfo" claim.
      */
    def getUserInfo(token: String): Option[UserInfo] =
      decodeClaim(token).flatMap(extractUserInfo)

    /** Issues a token whose claim is `{"userinfo": userinfo}`.
      * NOTE(review): no `exp` claim is set, so issued tokens never expire.
      */
    def issueJwt(userinfo: UserInfo): String = {
      val claims = JwtClaim() + Json(DefaultFormats).write(("userinfo", userinfo))
      Jwt.encode(claims, secret, algorithm)
    }

    // Signature check, dispatched on the algorithm family via the bound pattern variable.
    private def isValidToken(token: String): Boolean = algorithm match {
      case algo: JwtAsymmetricAlgorithm => Jwt.isValid(token, secret, Seq(algo))
      case algo: JwtHmacAlgorithm       => Jwt.isValid(token, secret, Seq(algo))
      case _                            => false
    }

    // Claim JSON (second part returned by decodeRawAll) of a token that verifies.
    private def decodeClaim(token: String): Option[String] = {
      val decoded = algorithm match {
        case algo: JwtAsymmetricAlgorithm => Jwt.decodeRawAll(token, secret, Seq(algo))
        case algo: JwtHmacAlgorithm       => Jwt.decodeRawAll(token, secret, Seq(algo))
        case _ => Failure(new IllegalArgumentException(s"unsupported algorithm: $algorithm"))
      }
      decoded.toOption.map(_._2)
    }

    // The "userinfo" member of the claim, only when it is present and is a JSON object.
    private def extractUserInfo(claimJson: String): Option[UserInfo] =
      Try(parse(claimJson) \ "userinfo").toOption.collect {
        case obj: JObject => obj.values
      }
  }
}
POSClient.scala
package com.datatech.pos.client

import com.datatech.pos.messages.{PBPOSCommand, PBPOSCredential, SendCommandGrpc}
import io.grpc.stub.StreamObserver
import io.grpc.netty.{ NegotiationType, NettyChannelBuilder}
import io.grpc.CallOptions
import io.grpc.ClientCall
import io.grpc.ClientInterceptor
import io.grpc.ForwardingClientCall
import io.grpc.Metadata
import io.grpc.Metadata.Key
import io.grpc.MethodDescriptor
import io.grpc.ClientInterceptors
import java.util.concurrent.TimeUnit

/** Demo client: fetches a JWT, then calls a protected RPC with it. */
object POSClient {
  /** Attaches `jwt` to every outgoing call as the ASCII metadata header "jwt". */
  class AuthClientInterceptor(jwt: String) extends ClientInterceptor {
    // Created once instead of on every call; the server reads the same header name.
    private val jwtHeader: Metadata.Key[String] = Key.of("jwt", Metadata.ASCII_STRING_MARSHALLER)

    def interceptCall[ReqT, RespT](methodDescriptor: MethodDescriptor[ReqT, RespT], callOptions: CallOptions, channel: io.grpc.Channel): ClientCall[ReqT, RespT] =
      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](channel.newCall(methodDescriptor, callOptions)) {
        override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
          headers.put(jwtHeader, jwt)
          super.start(responseListener, headers)
        }
      }
  }

  def main(args: Array[String]): Unit = {
    //build connection channel (plain text, no interceptor: fetching the token needs no JWT)
    val unsafeChannel = NettyChannelBuilder
      .forAddress("192.168.0.189",50051)
      .negotiationType(NegotiationType.PLAINTEXT)
      .build()
    try {
      val authClient = SendCommandGrpc.blockingStub(unsafeChannel)
      val jwt = authClient.getAuthToken(PBPOSCredential(userid="johnny",password="p4ssw0rd")).jwt
      println(s"got jwt: $jwt")
      // Same underlying connection, but every call now carries the "jwt" header.
      val securedChannel = ClientInterceptors.intercept(unsafeChannel, new AuthClientInterceptor(jwt))
      val securedClient = SendCommandGrpc.blockingStub(securedChannel)
      val resp = securedClient.singleResponse(PBPOSCommand())
      println(s"secured response: $resp")
      // keep the process alive until Enter is pressed
      scala.io.StdIn.readLine()
    } finally {
      // Release the channel's connections and threads (previously never shut down).
      unsafeChannel.shutdown()
      unsafeChannel.awaitTermination(5, TimeUnit.SECONDS)
    }
  }
}