如何为 Spark Streaming 创建 MQTT sink(输出端)?

yhuiod9q  于 2021-05-29  发布在  Spark
关注(0)|答案(2)|浏览(509)

有一些示例说明了如何为 Spark Streaming 创建 MQTT 数据源[1][2]。但是,我想创建一个 MQTT sink(输出端),把结果发布到 MQTT,而不是使用 print() 方法。我尝试自己创建一个 MqttSink,但得到了 object not serializable 错误。随后我参考了这篇博客的代码重写,但在 MqttSink 对象上找不到我定义的 send 方法。

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{HashPartitioner, SparkConf}
import org.fusesource.mqtt.client.QoS
import org.sense.spark.util.{MqttSink, TaxiRideSource}

/**
 * Spark Streaming job that reads taxi rides from `TaxiRideSource`, aggregates them per
 * driver with `combineByKey`, and either prints every batch or publishes each record to
 * the MQTT topic `mqttTopic` when the first program argument is "mqtt".
 *
 * NOTE(review): this listing is the non-compiling version from the question; see the
 * comment on the `broadcast` call for the cause.
 */
object TaxiRideCountCombineByKey {

  // Topic the MQTT branch publishes to.
  val mqttTopic: String = "spark-mqtt-sink"
  // Not referenced anywhere else in this object.
  val qos: QoS = QoS.AT_LEAST_ONCE

  def main(args: Array[String]): Unit = {

    // MQTT output is enabled only when the first argument is exactly "mqtt".
    val outputMqtt: Boolean = if (args.length > 0 && args(0).equals("mqtt")) true else false

    // Create a local StreamingContext with four worker threads (local[4]) and a batch interval of 1 second.
    // The master requires 4 cores to prevent a starvation scenario.
    val sparkConf = new SparkConf()
      .setAppName("TaxiRideCountCombineByKey")
      .setMaster("local[4]")
    val ssc = new StreamingContext(sparkConf, Seconds(1))

    val stream = ssc.receiverStream(new TaxiRideSource())
    // Pair every ride with the constant 1, keyed by driver id.
    val driverStream = stream.map(taxiRide => (taxiRide.driverId, 1))
    // Per key, build a pair (sum of values, number of values) in 3 hash partitions.
    val countStream = driverStream.combineByKey(
      (v) => (v, 1), // createCombiner: first value seen for a key
      (acc: (Int, Int), v) => (acc._1 + v, acc._2 + 1), // mergeValue: fold one more value into the pair
      (acc1: (Int, Int), acc2: (Int, Int)) => (acc1._1 + acc2._1, acc1._2 + acc2._2), // mergeCombiners: add two partial pairs
      new HashPartitioner(3)
    )

    if (outputMqtt) {
      println("Use the command below to consume data:")
      println("mosquitto_sub -h 127.0.0.1 -p 1883 -t " + mqttTopic)

      // `MqttSink` without parentheses is the companion object itself, not an instance of
      // the class. The companion only defines `apply()`, while `send` is a method of the
      // class, so the call below does not resolve.
      val mqttSink = ssc.sparkContext.broadcast(MqttSink)
      countStream.foreachRDD { rdd =>
        rdd.foreach { message =>
          mqttSink.value.send(mqttTopic, message.toString()) // does not compile: the companion object has no `send`
        }
      }
    } else {
      countStream.print()
    }

    ssc.start() // Start the computation
    ssc.awaitTermination() // Wait for the computation to terminate
  }
}
import org.fusesource.mqtt.client.{FutureConnection, MQTT, QoS}

/**
 * Serializable handle for publishing to MQTT.
 *
 * Only the `createProducer` factory is held as constructor state; the connection is built
 * by calling it the first time `producer` is accessed.
 *
 * @param createProducer factory that returns the connection used by `send`
 */
class MqttSink(createProducer: () => FutureConnection) extends Serializable {
  // @transient keeps the connection out of the serialized form even if it was already
  // initialised; it is rebuilt from the factory on next access after deserialization.
  @transient lazy val producer: FutureConnection = createProducer()

  /**
   * Publishes `message` to `topic` with QoS AT_LEAST_ONCE and retain = false.
   * The value returned by `publish` is discarded.
   *
   * @param topic   MQTT topic to publish to
   * @param message payload text, encoded as UTF-8 regardless of the platform default charset
   */
  def send(topic: String, message: String): Unit = {
    val payload = message.getBytes(java.nio.charset.StandardCharsets.UTF_8)
    producer.publish(topic, payload, QoS.AT_LEAST_ONCE, false)
  }
}

object MqttSink {
  /**
   * Builds an [[MqttSink]] whose connection factory connects to the given broker.
   *
   * Nothing is connected here: the factory only runs when the sink's `producer` is first
   * accessed. When it runs it blocks until the connection is established and registers a
   * JVM shutdown hook that disconnects.
   *
   * @param host broker host name; defaults to "localhost"
   * @param port broker port; defaults to 1883
   * @return a sink wrapping the connection factory
   */
  def apply(host: String = "localhost", port: Int = 1883): MqttSink = {
    val connect: () => FutureConnection = () => {
      val client = new MQTT()
      client.setHost(host, port)
      val connection = client.futureConnection()
      connection.connect().await()
      sys.addShutdownHook {
        connection.disconnect().await()
      }
      connection
    }
    new MqttSink(connect)
  }
}
y4ekin9u

y4ekin9u1#

这是一个可运行的示例,基于博客文章"Spark 和 Kafka 集成模式"中的做法。

package org.sense.spark.app

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{HashPartitioner, SparkConf}
import org.fusesource.mqtt.client.QoS
import org.sense.spark.util.{MqttSink, TaxiRideSource}

/**
 * Spark Streaming job that reads taxi rides from `TaxiRideSource`, aggregates them per
 * driver with `combineByKey`, and either prints every batch or publishes each record to
 * the MQTT topic `mqttTopic` when the first program argument is "mqtt".
 */
object TaxiRideCountCombineByKey {

  // Topic the MQTT branch publishes to.
  val mqttTopic: String = "spark-mqtt-sink"
  // Not referenced anywhere else in this object.
  val qos: QoS = QoS.AT_LEAST_ONCE

  /**
   * Entry point.
   *
   * @param args pass "mqtt" as the first argument to publish results to MQTT; with any
   *             other value, or no arguments, every batch is printed instead
   */
  def main(args: Array[String]): Unit = {

    val outputMqtt: Boolean = args.headOption.contains("mqtt")

    // Local StreamingContext with four worker threads (local[4]) and a 1 second batch interval.
    // The master requires 4 cores to prevent a starvation scenario.
    val sparkConf = new SparkConf()
      .setAppName("TaxiRideCountCombineByKey")
      .setMaster("local[4]")
    val ssc = new StreamingContext(sparkConf, Seconds(1))

    val stream = ssc.receiverStream(new TaxiRideSource())
    // Pair every ride with the constant 1, keyed by driver id.
    val driverStream = stream.map(taxiRide => (taxiRide.driverId, 1))
    // Per key, build a pair (sum of values, number of values) in 3 hash partitions.
    val countStream = driverStream.combineByKey(
      (v: Int) => (v, 1), // createCombiner: first value seen for a key
      (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1), // mergeValue: fold one more value into the pair
      (acc1: (Int, Int), acc2: (Int, Int)) => (acc1._1 + acc2._1, acc1._2 + acc2._2), // mergeCombiners: add two partial pairs
      new HashPartitioner(3)
    )

    if (outputMqtt) {
      println("Use the command below to consume data:")
      println(s"mosquitto_sub -h 127.0.0.1 -p 1883 -t $mqttTopic")

      // `MqttSink()` calls the companion's `apply` and yields an instance of the class,
      // which is what exposes `send`. Broadcasting ships only the serializable wrapper.
      val mqttSink = ssc.sparkContext.broadcast(MqttSink())
      countStream.foreachRDD { rdd =>
        rdd.foreach { message =>
          mqttSink.value.send(mqttTopic, message.toString)
        }
      }
    } else {
      countStream.print()
    }

    ssc.start() // Start the computation
    ssc.awaitTermination() // Wait for the computation to terminate
  }
}
package org.sense.spark.util

import org.fusesource.mqtt.client.{FutureConnection, MQTT, QoS}

/**
 * Serializable handle for publishing to MQTT.
 *
 * Only the `createProducer` factory is held as constructor state; the connection is built
 * by calling it the first time `producer` is accessed.
 *
 * @param createProducer factory that returns the connection used by `send`
 */
class MqttSink(createProducer: () => FutureConnection) extends Serializable {

  // @transient keeps the connection out of the serialized form even if it was already
  // initialised; it is rebuilt from the factory on next access after deserialization.
  @transient lazy val producer: FutureConnection = createProducer()

  /**
   * Publishes `message` to `topic` with QoS AT_LEAST_ONCE and retain = false.
   * The value returned by `publish` is discarded.
   *
   * @param topic   MQTT topic to publish to
   * @param message payload text, encoded as UTF-8 regardless of the platform default charset
   */
  def send(topic: String, message: String): Unit = {
    val payload = message.getBytes(java.nio.charset.StandardCharsets.UTF_8)
    producer.publish(topic, payload, QoS.AT_LEAST_ONCE, false)
  }
}

object MqttSink {
  /**
   * Builds an [[MqttSink]] whose connection factory connects to the given broker.
   *
   * Nothing is connected here: the factory only runs when the sink's `producer` is first
   * accessed. When it runs it blocks until the connection is established and registers a
   * JVM shutdown hook that disconnects.
   *
   * @param host broker host name; defaults to "localhost"
   * @param port broker port; defaults to 1883
   * @return a sink wrapping the connection factory
   */
  def apply(host: String = "localhost", port: Int = 1883): MqttSink = {
    val connect: () => FutureConnection = () => {
      val client = new MQTT()
      client.setHost(host, port)
      val connection = client.futureConnection()
      connection.connect().await()
      sys.addShutdownHook {
        connection.disconnect().await()
      }
      connection
    }
    new MqttSink(connect)
  }
}
package org.sense.spark.util

import java.io.{BufferedReader, FileInputStream, InputStreamReader}
import java.nio.charset.StandardCharsets
import java.util.Locale
import java.util.zip.GZIPInputStream

import org.apache.spark.storage._
import org.apache.spark.streaming.receiver._
import org.joda.time.DateTime
import org.joda.time.format.{DateTimeFormat, DateTimeFormatter}

/**
 * One record of the taxi ride data set, as produced by
 * `TaxiRideSource.getTaxiRideFromString`.
 *
 * @param rideId       ride identifier
 * @param isStart      true for a "START" record, false for an "END" record
 * @param startTime    first timestamp column of the record
 * @param endTime      second timestamp column of the record
 * @param startLon     start longitude (0.0 when the column is empty)
 * @param startLat     start latitude (0.0 when the column is empty)
 * @param endLon       end longitude (0.0 when the column is empty)
 * @param endLat       end latitude (0.0 when the column is empty)
 * @param passengerCnt passenger count
 * @param taxiId       taxi identifier
 * @param driverId     driver identifier
 */
case class TaxiRide(
  rideId: Long,
  isStart: Boolean,
  startTime: DateTime,
  endTime: DateTime,
  startLon: Float,
  startLat: Float,
  endLon: Float,
  endLat: Float,
  passengerCnt: Short,
  taxiId: Long,
  driverId: Long
)

/** Holds the shared formatter used to parse ride timestamps. */
object TimeFormatter {
  /** Parses "yyyy-MM-dd HH:mm:ss" timestamps using the US locale and the UTC zone. */
  val timeFormatter: DateTimeFormatter =
    DateTimeFormat
      .forPattern("yyyy-MM-dd HH:mm:ss")
      .withLocale(Locale.US)
      .withZoneUTC()
}

/**
 * Spark Streaming receiver that replays a gzip-compressed CSV file of taxi rides.
 *
 * Each line is parsed into a [[TaxiRide]] and handed to `store`. When the end of the file
 * is reached the file is reopened and replayed, until the receiver is stopped. The pace of
 * emission is delegated to `DataRateListener.busySleep`.
 */
class TaxiRideSource extends Receiver[TaxiRide](StorageLevel.MEMORY_AND_DISK_2) {
  val dataFilePath = "/home/flink/nycTaxiRides.gz"
  var dataRateListener: DataRateListener = _

  /**
   * Starts the rate listener and a background thread that reads the file.
   * Returns immediately; the reading happens on the "TaxiRide Source" thread.
   */
  def onStart(): Unit = {
    dataRateListener = new DataRateListener()
    dataRateListener.start()
    new Thread("TaxiRide Source") {
      override def run(): Unit = {
        receive()
      }
    }.start()
  }

  /** Nothing to do here: the reading thread ends by itself once `isStopped()` is true. */
  def onStop(): Unit = {}

  /**
   * Reads the file line by line, stores one TaxiRide per line and regulates the emission
   * frequency, reopening the file after each full pass.
   */
  private def receive(): Unit = {
    while (!isStopped()) {
      val gzipStream = new GZIPInputStream(new FileInputStream(dataFilePath))
      val reader = new BufferedReader(new InputStreamReader(gzipStream, StandardCharsets.UTF_8))
      try {
        var line: String = null
        do {
          // start time before reading the line, consumed by the rate limiter below
          val startTime = System.nanoTime

          line = reader.readLine
          if (line != null) {
            store(getTaxiRideFromString(line))
          }

          // regulate frequency of the source
          dataRateListener.busySleep(startTime)
          // Also leave the pass when the receiver is stopped, instead of only at end of
          // file; otherwise a stop request is ignored until the whole file has been read.
        } while (line != null && !isStopped())
      } finally {
        reader.close()
      }
    }
  }

  /**
   * Parses one CSV line into a [[TaxiRide]].
   *
   * @param line comma-separated record with exactly 11 columns; column 1 must be "START"
   *             or "END"
   * @return the parsed ride; empty coordinate columns become 0.0
   * @throws RuntimeException if the column count is not 11 or column 1 is neither "START"
   *                          nor "END"
   */
  def getTaxiRideFromString(line: String): TaxiRide = {
    val tokens: Array[String] = line.split(",")
    if (tokens.length != 11) {
      throw new RuntimeException("Invalid record: " + line)
    }

    val rideId: Long = tokens(0).toLong
    val isStart: Boolean = tokens(1) match {
      case "START" => true
      case "END" => false
      case _ => throw new RuntimeException("Invalid record: " + line)
    }
    // Both record kinds carry their two timestamps in the same columns.
    val startTime: DateTime = DateTime.parse(tokens(2), TimeFormatter.timeFormatter)
    val endTime: DateTime = DateTime.parse(tokens(3), TimeFormatter.timeFormatter)
    val startLon: Float = parseFloatOrZero(tokens(4))
    val startLat: Float = parseFloatOrZero(tokens(5))
    val endLon: Float = parseFloatOrZero(tokens(6))
    val endLat: Float = parseFloatOrZero(tokens(7))
    val passengerCnt: Short = tokens(8).toShort
    val taxiId: Long = tokens(9).toLong
    val driverId: Long = tokens(10).toLong

    TaxiRide(rideId, isStart, startTime, endTime, startLon, startLat, endLon, endLat, passengerCnt, taxiId, driverId)
  }

  /** Parses a coordinate column, mapping an empty column to 0.0. */
  private def parseFloatOrZero(token: String): Float =
    if (token.nonEmpty) token.toFloat else 0.0f
}
hujrc8aj

hujrc8aj2#

作为替代方案,您还可以使用结构化流(Structured Streaming),配合 Apache Bahir 提供的 Spark 扩展来读写 MQTT。

完整示例

build.sbt:

name := "MQTT_StructuredStreaming"
version := "0.1"
// Pin the Scala version so that %% resolves the same _2.12 artifacts that were previously
// spelled out by hand, independently of the sbt launcher's default Scala version.
scalaVersion := "2.12.10"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.4.4",
  "org.apache.spark" %% "spark-sql" % "2.4.4",
  "org.apache.spark" %% "spark-streaming" % "2.4.4" % "provided",
  "org.apache.bahir" %% "spark-sql-streaming-mqtt" % "2.4.0"
)

Main.scala:

import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
 * Structured Streaming word count over MQTT: subscribes to `subTopicName`, splits every
 * payload on ";", keeps a running count per word and publishes the counts to
 * `pubTopicName` through the Apache Bahir MQTT source and sink providers.
 *
 * Uses an explicit `main` method instead of `extends App`, which the Spark documentation
 * advises against for applications.
 */
object Main {

  val brokerURL: String = "tcp://localhost:1883"
  val subTopicName: String = "/my/subscribe/topic"
  val pubTopicName: String = "/my/publish/topic"

  /**
   * Entry point; blocks until the streaming query terminates.
   *
   * @param args not used
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("MQTT_StructStreaming")
      .master("local[*]")
      .config("spark.sql.streaming.checkpointLocation", "/my/sparkCheckpoint/dir")
      .getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")

    import spark.implicits._

    // One row per MQTT message, with the payload cast to a string.
    val lines: Dataset[String] = spark.readStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
      .option("topic", subTopicName)
      .option("clientId", "some-client-id")
      .option("persistence", "memory")
      .load(brokerURL)
      .selectExpr("CAST(payload AS STRING)")
      .as[String]

    // Split the lines into words
    val words: Dataset[String] = lines.flatMap(_.split(";"))

    // Generate running word count
    val wordCounts: DataFrame = words.groupBy("value").count()

    // Start the query that publishes the running counts to the MQTT topic `pubTopicName`
    val query: StreamingQuery = wordCounts.writeStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider")
      .outputMode("complete")
      .option("topic", pubTopicName)
      .option("brokerURL", brokerURL)
      .start()

    query.awaitTermination()
  }
}

相关问题