@@ -2,12 +2,12 @@ package io.weaviate.spark
 
 import com.google.gson.reflect.TypeToken
 import com.google.gson.{Gson, JsonSyntaxException}
+import io.weaviate.client6.v1.api.collections.{WeaviateObject, DataType => WeaviateDataType}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
 import org.apache.spark.sql.types._
-import io.weaviate.client.v1.data.model.WeaviateObject
-import io.weaviate.client.v1.schema.model.WeaviateClass
+import io.weaviate.client6.v1.api.collections.{CollectionConfig, Vectors}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 
 import java.util.{Map => JavaMap}
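
The import swap above is the heart of the migration: the model types move from the v1 packages (`client.v1.data.model`, `client.v1.schema.model`) to `io.weaviate.client6.v1.api.collections`, and `WeaviateObject` becomes generic in its properties payload. A minimal sketch of what that shape means for this writer, assuming the client6 jar is on the classpath; the `RawObject` alias is hypothetical, not part of this change:

```scala
import java.util.{Map => JavaMap}
import io.weaviate.client6.v1.api.collections.WeaviateObject

object Client6Types {
  // client6's WeaviateObject is parameterized by its properties payload;
  // this writer uses the raw java.util.Map form throughout, spelled out as:
  type RawObject = WeaviateObject[JavaMap[String, Object]]
}
```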
@@ -20,54 +20,53 @@ case class WeaviateCommitMessage(msg: String) extends WriterCommitMessage
 
 case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructType)
   extends DataWriter[InternalRow] with Serializable with Logging {
-  var batch = mutable.Map[String, WeaviateObject]()
-  private val weaviateClass = weaviateOptions.getWeaviateClass()
+  var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object]]]()
+  private lazy val weaviateClass = weaviateOptions.getCollectionConfig()
 
   override def write(record: InternalRow): Unit = {
     val weaviateObject = buildWeaviateObject(record, weaviateClass)
-    batch += (weaviateObject.getId -> weaviateObject)
+    batch += (weaviateObject.uuid -> weaviateObject)
 
     if (batch.size >= weaviateOptions.batchSize) writeBatch()
   }
 
   def writeBatch(retries: Int = weaviateOptions.retries): Unit = {
-    if (batch.size == 0) return
+    if (batch.isEmpty) return
 
-    val consistencyLevel = weaviateOptions.consistencyLevel
     val client = weaviateOptions.getClient()
 
-    val results = if (consistencyLevel != "") {
-      logInfo(s"Writing using consistency level: ${consistencyLevel}")
-      client.batch().objectsBatcher().withObjects(batch.values.toList: _*).withConsistencyLevel(consistencyLevel).run()
-    } else {
-      client.batch().objectsBatcher().withObjects(batch.values.toList: _*).run()
-    }
+    val collection = client.collections
+      .use(weaviateOptions.className)
+      .withTenant(weaviateOptions.tenant)
+      .withConsistencyLevel(weaviateOptions.consistencyLevel)
+
+    val results = collection.data.insertMany(batch.values.toList.asJava)
 
     val IDs = batch.keys.toList
 
-    if (results.hasErrors || results.getResult == null) {
+    if (results.errors() != null && !results.errors().isEmpty) {
       if (retries == 0) {
         throw WeaviateResultError(s"error getting result and no more retries left. " +
-          s"Error from Weaviate: ${results.getError.getMessages}")
+          s"Error from Weaviate: ${results.errors().asScala.mkString(", ")}")
       }
       if (retries > 0) {
-        logError(s"batch error: ${results.getError.getMessages}, will retry")
+        logError(s"batch error: ${results.errors().asScala.mkString(", ")}, will retry")
         logInfo(s"Retrying batch in ${weaviateOptions.retriesBackoff} seconds. Batch has following IDs: ${IDs}")
         Thread.sleep(weaviateOptions.retriesBackoff * 1000)
         writeBatch(retries - 1)
       }
     } else {
-      val (objectsWithSuccess, objectsWithError) = results.getResult.partition(_.getResult.getErrors == null)
-      if (objectsWithError.size > 0 && retries > 0) {
-        val errors = objectsWithError.map(obj => s"${obj.getId}: ${obj.getResult.getErrors.toString}")
-        val successIDs = objectsWithSuccess.map(_.getId).toList
+      val (objectsWithSuccess, objectsWithError) = results.responses().asScala.partition(_.error() == null)
+      if (objectsWithError.nonEmpty && retries > 0) {
+        val errors = objectsWithError.map(obj => s"${obj.uuid()}: ${obj.error()}")
+        val successIDs = objectsWithSuccess.map(_.uuid()).toList
         logWarning(s"Successfully imported ${successIDs}. " +
           s"Retrying objects with an error. Following objects in the batch upload had an error: ${errors.mkString("Array(", ", ", ")")}")
         batch = batch -- successIDs
         writeBatch(retries - 1)
-      } else if (objectsWithError.size > 0) {
-        val errorIds = objectsWithError.map(obj => obj.getId)
-        val errorMessages = objectsWithError.map(obj => obj.getResult.getErrors.toString).distinct
+      } else if (objectsWithError.nonEmpty) {
+        val errorIds = objectsWithError.map(obj => obj.uuid())
+        val errorMessages = objectsWithError.map(obj => obj.error()).distinct
         throw WeaviateResultError(s"Error writing to weaviate and no more retries left. " +
           s"IDs with errors: ${errorIds.mkString("Array(", ", ", ")")}. " +
           s"Error messages: ${errorMessages.mkString("Array(", ", ", ")")}")
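
The retry strategy above deserves a close read: a batch-level failure (`results.errors()`) retries the whole batch, while per-object failures shrink the batch to just the failed objects before recursing. A self-contained sketch of that algorithm; `Response` is a toy stand-in for client6's per-object result (`uuid()`/`error()`), not the real class:

```scala
import scala.collection.mutable

// Toy stand-in for the per-object insert result: error == null means success.
final case class Response(uuid: String, error: String)

object RetrySketch {
  var batch = mutable.Map("a" -> "obj-a", "b" -> "obj-b") // uuid -> payload stand-in

  def writeBatch(send: Map[String, String] => List[Response], retries: Int): Unit = {
    if (batch.isEmpty) return
    val results = send(batch.toMap)
    val (ok, failed) = results.partition(_.error == null)
    if (failed.nonEmpty && retries > 0) {
      batch = batch -- ok.map(_.uuid) // keep only the failed objects for the retry
      writeBatch(send, retries - 1)
    } else if (failed.nonEmpty) {
      throw new RuntimeException(s"no retries left, failed IDs: ${failed.map(_.uuid)}")
    }
  }
}

// e.g. RetrySketch.writeBatch(m => m.keys.toList.map(Response(_, null)), retries = 2)
```

Note that a fully successful pass does not clear `batch` here; in the real writer that happens outside the lines shown in this hunk.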
@@ -79,17 +78,16 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
     }
   }
 
-  private[spark] def buildWeaviateObject(record: InternalRow, weaviateClass: WeaviateClass = null): WeaviateObject = {
-    var builder = WeaviateObject.builder.className(weaviateOptions.className)
-    if (weaviateOptions.tenant != null) {
-      builder = builder.tenant(weaviateOptions.tenant)
-    }
+  private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object]] = {
+    val builder: WeaviateObject.Builder[java.util.Map[String, Object]] = new WeaviateObject.Builder()
+
     val properties = mutable.Map[String, AnyRef]()
+    var vector: Array[Float] = null
     val vectors = mutable.Map[String, Array[Float]]()
     val multiVectors = mutable.Map[String, Array[Array[Float]]]()
     schema.zipWithIndex.foreach(field =>
       field._1.name match {
-        case weaviateOptions.vector => builder = builder.vector(record.getArray(field._2).toArray(FloatType))
+        case weaviateOptions.vector => vector = record.getArray(field._2).toArray(FloatType)
         case key if weaviateOptions.vectors.contains(key) => vectors += (weaviateOptions.vectors(key) -> record.getArray(field._2).toArray(FloatType))
         case key if weaviateOptions.multiVectors.contains(key) => {
           val multiVectorArrayData = record.get(field._2, ArrayType(ArrayType(FloatType))) match {
@@ -105,34 +103,40 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
 
           multiVectors += (weaviateOptions.multiVectors(key) -> multiVector)
         }
-        case weaviateOptions.id => builder = builder.id(record.getString(field._2))
-        case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, weaviateClass)
+        case weaviateOptions.id => builder.uuid(record.getString(field._2))
+        case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, collectionConfig)
       }
     )
+
     if (weaviateOptions.id == null) {
-      builder.id(java.util.UUID.randomUUID.toString)
+      builder.uuid(java.util.UUID.randomUUID.toString)
     }
 
+    val allVectors = ListBuffer.empty[Vectors]
+    if (vector != null) {
+      allVectors += Vectors.of(vector)
+    }
     if (vectors.nonEmpty) {
-      builder.vectors(vectors.map { case (key, arr) => key -> arr.map(Float.box) }.asJava)
+      allVectors ++= vectors.map { case (key, arr) => Vectors.of(key, arr) }
     }
     if (multiVectors.nonEmpty) {
-      builder.multiVectors(multiVectors.map { case (key, multiVector) => key -> multiVector.map { vec => { vec.map(Float.box) }} }.toMap.asJava)
+      allVectors ++= multiVectors.map { case (key, multiVector) => Vectors.of(key, multiVector) }
     }
-    builder.properties(properties.asJava).build
+
+    builder.tenant(weaviateOptions.tenant).properties(properties.asJava).vectors(allVectors.toSeq: _*).build()
   }
 
-  def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, weaviateClass: WeaviateClass): AnyRef = {
+  def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, collectionConfig: CollectionConfig): AnyRef = {
     val valueFromField = getValueFromField(index, record, dataType, parseObjectArrayItem)
-    if (weaviateClass != null) {
+    if (collectionConfig != null) {
       var dt = ""
-      weaviateClass.getProperties.forEach(p => {
-        if (p.getName == propertyName) {
+      collectionConfig.properties().forEach(p => {
+        if (p.propertyName() == propertyName) {
           // we are just looking for geoCoordinates or phoneNumber type
-          dt = p.getDataType.get(0)
+          dt = p.dataTypes().get(0)
         }
       })
-      if ((dt == "geoCoordinates" || dt == "phoneNumber") && valueFromField.isInstanceOf[String]) {
+      if ((dt == WeaviateDataType.GEO_COORDINATES || dt == WeaviateDataType.PHONE_NUMBER) && valueFromField.isInstanceOf[String]) {
         return jsonToJavaMap(propertyName, valueFromField.toString).get
       }
     }
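
The builder no longer takes raw float maps (the old `vectors(...asJava)` and `multiVectors(...)` calls); every vector flavor is funneled through `Vectors.of` and handed to the builder in a single varargs call. A condensed sketch of that assembly step, using only the `Vectors.of` overloads that appear in this diff, with toy values and the client6 jar assumed on the classpath:

```scala
import scala.collection.mutable.ListBuffer
import io.weaviate.client6.v1.api.collections.Vectors

object VectorAssemblySketch {
  val defaultVec: Array[Float] = Array(0.1f, 0.2f)        // unnamed default vector
  val named = Map("title" -> Array(0.3f, 0.4f))           // named vectors
  val multi = Map("colbert" -> Array(Array(0.5f, 0.6f)))  // named multi-vectors

  val allVectors = ListBuffer.empty[Vectors]
  if (defaultVec != null) allVectors += Vectors.of(defaultVec)
  allVectors ++= named.map { case (k, v) => Vectors.of(k, v) }
  allVectors ++= multi.map { case (k, v) => Vectors.of(k, v) }
  // allVectors.toSeq is then splatted into WeaviateObject.Builder#vectors(...)
}
```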
@@ -209,7 +213,7 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
       })
     }
     objList.asJava
-    case default => throw new SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
+    case default => throw SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
   }
 
 
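
A small cleanup rides along in this hunk: `SparkDataTypeNotSupported` is now thrown via its companion `apply` instead of `new`. That only compiles because the error type is (presumably) a case class; a toy equivalent:

```scala
// Assumed shape of the error type; the real definition lives elsewhere in this repo.
case class SparkDataTypeNotSupported(msg: String) extends Exception(msg)

object ThrowDemo extends App {
  // For a case class, `new` is redundant: the synthetic companion apply() constructs it.
  try throw SparkDataTypeNotSupported("DataType CalendarIntervalType is not supported by Weaviate")
  catch { case e: SparkDataTypeNotSupported => println(e.getMessage) }
}
```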
@@ -224,7 +228,7 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
   }
 
   override def abort(): Unit = {
-    // TODO rollback previously written batch results if issue occured
+    // TODO rollback previously written batch results if issue occurred
     logError("Aborted data write")
   }
 }
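
For completeness: the Gson imports kept at the top of the file serve `jsonToJavaMap`, which `getPropertyValue` calls (and `.get`s) for `geoCoordinates`/`phoneNumber` columns. This diff does not touch the helper itself; a plausible reconstruction of its shape, assuming it returns a `Try` (the `.get` call is consistent with either `Try` or `Option`):

```scala
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import java.util.{Map => JavaMap}
import scala.util.Try

object JsonMapSketch {
  // Hypothetical reconstruction: parse a JSON string (e.g. a geoCoordinates payload)
  // into the java.util.Map shape Weaviate expects for the property; a malformed
  // string surfaces as Failure(JsonSyntaxException) rather than a thrown exception.
  def jsonToJavaMap(propertyName: String, json: String): Try[JavaMap[String, Object]] = {
    val mapType = new TypeToken[JavaMap[String, Object]]() {}.getType
    Try(new Gson().fromJson[JavaMap[String, Object]](json, mapType))
  }
}

// JsonMapSketch.jsonToJavaMap("location", """{"latitude": 52.37, "longitude": 4.89}""")
```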