Using Kedro to process datasets in batches asynchronously
In this section we take a deep dive into how we optimized our embedding generation by leveraging PySpark and Kedro functionality.
Preliminaries
Our organisation has been focussing on large-scale data enrichment. One of the problems we've run into on multiple occasions is using an external API to enrich elements of a huge dataset, i.e., millions of rows, in a batched manner. The problem becomes more pronounced whenever the upstream API enforces rate limiting, e.g., OpenAI when computing embeddings.
Initial approach
We started with a Spark operation that performs the batching and invokes a UDF that implements the OpenAI call on each batch.
from typing import Callable, List

import pyspark.sql.functions as F
from pyspark.sql import DataFrame, Window


def compute_embeddings(
    input: DataFrame,
    features: List[str],
    fn: Callable,  # pyspark UDF that calls OpenAI on a batch of inputs
    attribute: str,  # name of the output column, e.g. "embedding"
    batch_size: int,
) -> DataFrame:
    window = Window.orderBy(F.lit(1))
    res = (
        # Assign a global row number and derive a batch id from it
        input.withColumn("row_num", F.row_number().over(window))
        .withColumn("batch", F.floor((F.col("row_num") - 1) / batch_size))
        # Concatenate the feature columns into a single input string
        .withColumn(
            "input",
            F.concat(*[F.coalesce(F.col(feature), F.lit("")) for feature in features]),
        )
        # Collect ids and inputs per batch and invoke the UDF once per batch
        .groupBy("batch")
        .agg(
            F.collect_list("id").alias("id"),
            F.collect_list("input").alias("input"),
        )
        .withColumn(attribute, fn(F.col("input")))
        # Zip ids and embeddings back together and explode into one row per id
        .withColumn("_conc", F.arrays_zip(F.col("id"), F.col(attribute)))
        .withColumn("exploded", F.explode(F.col("_conc")))
        .select(
            F.col("exploded.id").alias("id"),
            F.col(f"exploded.{attribute}").alias(attribute),
        )
        .join(input, on="id")
    )
    return res
The operation above has huge hardware requirements, and its parallelism is tied to the number of partitions in PySpark. On multiple occasions the function above ground to a halt, signalling underlying resource issues. We felt that PySpark was not the right choice for this workload.
Towards a better approach
Hive partitioning: our approach relies on hive partitioning to produce a batched dataset. Hive partitioning is a technique whereby a column of the data is "promoted" to directories on file storage. For instance, if we apply hive partitioning to a table with columns [id, batch], the files on object storage will be laid out as follows:
.
└── dataset/
    ├── batch=1/
    │   └── df.parquet
    ├── batch=2/
    │   └── df.parquet
    └── ...
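To make this concrete, here is a minimal sketch of how such a layout can be produced and read back with plain PySpark (the column and path names are illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy table with an id and a batch column
df = spark.createDataFrame([(i, i // 2) for i in range(6)], ["id", "batch"])

# partitionBy "promotes" batch to directories: /tmp/dataset/batch=0/, batch=1/, ...
df.write.mode("overwrite").partitionBy("batch").parquet("/tmp/dataset")

# Reading the directory back recovers "batch" as a regular column
spark.read.parquet("/tmp/dataset").show()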
Bucketizing the input dataframe
As a first step, we make sure to add a bucket column to our data.
from typing import List

import pyspark.sql.functions as F
from pyspark.sql import DataFrame, SparkSession, Window


def bucketize_df(df: DataFrame, bucket_size: int, input_features: List[str], max_input_len: int) -> DataFrame:
    """Function to bucketize the input dataframe.

    Bucketizes the input dataframe into N buckets, each of size `bucket_size`
    elements. Moreover, it concatenates the `input_features` into a single column
    and clips its length to `max_input_len`.

    Args:
        df: Dataframe to bucketize
        bucket_size: number of elements per bucket
        input_features: feature columns to concatenate into the text to embed
        max_input_len: maximum length of the concatenated text
    """
    # Retrieve the number of elements and derive the number of buckets (ceiling division)
    num_elements = df.count()
    num_buckets = (num_elements + bucket_size - 1) // bucket_size

    # Construct df holding the bucket ranges
    spark_session: SparkSession = SparkSession.builder.getOrCreate()

    # Bucketize df
    # NOTE: Alternatively can use .repartition?
    buckets = spark_session.createDataFrame(
        data=[(bucket, bucket * bucket_size, (bucket + 1) * bucket_size) for bucket in range(num_buckets)],
        schema=["bucket", "min_range", "max_range"],
    )

    # Order the elements and assign each row to its bucket via a range join
    return (
        df.withColumn("row_num", F.row_number().over(Window.orderBy("id")) - F.lit(1))
        .join(buckets, on=(F.col("row_num") >= F.col("min_range")) & (F.col("row_num") < F.col("max_range")))
        # Concat input features
        .withColumn(
            "text_to_embed",
            F.concat(*[F.coalesce(F.col(feature), F.lit("")) for feature in input_features]),
        )
        # Clip max. length
        .withColumn("text_to_embed", F.substring(F.col("text_to_embed"), 1, max_input_len))
        .select("id", *input_features, "text_to_embed", "bucket")
    )
We now use Spark's native partitionBy save argument on the bucket column to ensure the dataset is written in a hive-partitioned manner.
embeddings.feat.bucketized_nodes@spark:
<<: *_spark_parquet
filepath: ${globals:paths.tmp}/feat/bucketized_nodes
save_args:
mode: overwrite
partitionBy:
- bucket
Loading the input as a partitioned DF
Next, we wish to process the dataframe in shards; for this we will use Kedro's PartitionedDataset. The partitioned dataset is interesting in the sense that it does not load any data eagerly, but instead provides the node with a dictionary mapping each shard's path to its load function. It is important to remember that a shard's data is only loaded when its load function is invoked.
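To make the lazy-loading behaviour concrete, a node consuming a PartitionedDataset receives a dictionary along the following lines (a minimal sketch; the shard paths and load callables are supplied by Kedro):

from typing import Callable, Dict

import pandas as pd


def summarise_shards(partitions: Dict[str, Callable[[], pd.DataFrame]]) -> pd.DataFrame:
    """Hypothetical node: each value is a callable that lazily loads one shard."""
    counts = {}
    for shard_path, load_func in partitions.items():
        df = load_func()  # the shard is only read from storage at this point
        counts[shard_path] = len(df)
    return pd.DataFrame(list(counts.items()), columns=["shard", "rows"])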
Note: to read/load the same underlying dataset in different formats in Kedro, we use transcoding. This allows us to redefine how the dataset is loaded, while ensuring the Kedro DAG does not become disconnected.
embeddings.feat.bucketized_nodes@partitioned:
type: matrix_gcp_datasets.gcpPartitionedAsyncParallelDataset
path: ${globals:paths.tmp}/feat/bucketized_nodes
dataset:
# NOTE: Switching between spark/pandas thanks to underlying parquet structure
<<: *_pandas_parquet
filename_suffix: ".parquet"
Setting up the processing logic
Next, let's set up the processing logic; we codify this as a function that processes an individual dataframe.
# Configuration for node embeddings
embeddings.node:
  # The following configuration ensures that OpenAI requests are
  # batched in batches of 500 input elements, where each input
  # element is clipped at 100 characters. This keeps each request at
  # approx (100 * 500) / 6 ≈ 8,300 tokens, within the embeddings
  # call's token limit of roughly 8,500 tokens.
batch_size: 500
max_input_len: 100
input_features: ["category", "name"]
model:
object: langchain_openai.OpenAIEmbeddings
model: text-embedding-3-small
openai_api_key: ${oc.env:OPENAI_API_KEY}
dimensions: 100
timeout: 10
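In the node below, the model section of this configuration is materialised into an actual OpenAIEmbeddings instance before use. The @inject_object() decorator that does this is internal to our codebase; conceptually it behaves roughly like the following sketch (the helper name and its exact behaviour are assumptions, not our actual implementation):

import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Hypothetical helper: turn {"object": "pkg.Class", **kwargs} into an instance."""
    kwargs = dict(config)
    module_path, class_name = kwargs.pop("object").rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


# e.g. instantiate_from_config(model_params) -> langchain_openai.OpenAIEmbeddings(...)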
Note that Kedro's native PartitionedDataset does not support asynchronous save functions out of the box; we come back to this further below.
import pandas as pd
from tenacity import retry, stop_after_attempt, wait_exponential


@retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(3))
async def compute_df_embeddings_async(df: pd.DataFrame, embedding_model) -> pd.DataFrame:
    try:
        # Embed entities in batch mode
        combined_texts = df["text_to_embed"].tolist()
        df["embedding"] = await embedding_model.aembed_documents(combined_texts)
    except Exception as e:
        print(f"Exception occurred: {e}")
        raise e

    # Drop the helper column again before saving
    df = df.drop(columns=["text_to_embed"])
    return df
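For reference, here is how this function could be exercised on a single shard outside of Kedro; a minimal sketch where the sample dataframe is illustrative and OPENAI_API_KEY is assumed to be set in the environment:

import asyncio

import pandas as pd
from langchain_openai import OpenAIEmbeddings

# Hypothetical shard with the text_to_embed column produced by bucketize_df
shard = pd.DataFrame({"id": [1, 2], "text_to_embed": ["espresso machine", "office chair"]})

model = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=100, timeout=10)

result = asyncio.run(compute_df_embeddings_async(shard, model))
print(result.columns.tolist())  # ["id", "embedding"]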
Connecting the dots
We've now set up a Kedro dataset to load the hive-partitioned data as a Kedro PartitionedDataset, and we wish to produce a new Kedro PartitionedDataset that includes the embeddings.
embeddings.feat.graph.node_embeddings@partitioned:
type: matrix_gcp_datasets.gcpPartitionedAsyncParallelDataset
overwrite: True # important otherwise not properly reset on rerun
path: ${globals:paths.tmp}/feat/tmp_nodes_with_embeddings
dataset:
<<: *_pandas_parquet
filename_suffix: ".parquet"
Saving to a PartitionedDataset follows the same idea: from the Kedro node we output a dictionary mapping partition paths to the data to be saved, or to a callable producing it lazily. The trick is to set up this dictionary in such a manner that the save() call is responsible for invoking both the load() and compute_df_embeddings_async() functions.
from typing import Any, Callable, Dict

import pandas as pd


@inject_object()
def compute_embeddings(
    dfs: Dict[str, Any],
    model: Dict[str, Any],
):
    """Function to compute embeddings for the bucketized shards.

    Args:
        dfs: mapping of shard paths to their load functions
        model: embedding model to run
    """

    def _func(load_func: Callable[[], pd.DataFrame]):
        # NOTE: Very important to bind load_func via a default argument to avoid
        # late-binding closure issues inside the loop below
        return lambda load=load_func: compute_df_embeddings_async(load(), model)

    shards = {}
    for path, load_func in dfs.items():
        # A little hacky, but we extract the bucket from the hive-partitioned input path,
        # as we know the input paths to this dataset are of the format bucket={num}/...
        bucket = path.split("/")[0].split("=")[1]

        # Register a lazy callable that loads the shard and computes its embeddings
        shard_path = f"bucket={bucket}/shard"
        shards[shard_path] = _func(load_func)
    return shards
Parallelising the PartitionedDataset
Kedro's default PartitionedDataset does not support writing the individual shards in parallel, which is why we've extended its save() behaviour to execute on an async event loop. A semaphore limits the maximum number of partitions processed concurrently at any given time, which allows it to work with APIs that have rate limiting enabled.
import asyncio
from copy import deepcopy
from typing import Any

from kedro_datasets.partitions import PartitionedDataset
from tqdm import tqdm


class PartitionedAsyncParallelDataset(PartitionedDataset):
    """
    Custom implementation of the PartitionedDataset that allows concurrent processing.
    """

    def _save(self, data: dict[str, Any], max_workers: int = 10, timeout: int = 60) -> None:
if self._overwrite and self._filesystem.exists(self._normalized_path):
self._filesystem.rm(self._normalized_path, recursive=True)
# Helper function to process a single partition
async def process_partition(sem, partition_id, partition_data):
async with sem:
try:
# Set up arguments and path
kwargs = deepcopy(self._dataset_config)
partition = self._partition_to_path(partition_id)
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
# Evaluate partition data if it's callable
if callable(partition_data):
partition_data = await partition_data() # noqa: PLW2901
else:
raise RuntimeError("not callable")
# Save the partition data
dataset.save(partition_data)
except Exception as e:
print(f"Error in process_partition with partition {partition_id}: {e}")
raise
# Define function to run asyncio tasks within a synchronous function
def run_async_tasks():
# Create an event loop and a thread pool executor for async execution
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
sem = asyncio.Semaphore(max_workers)
tasks = [
loop.create_task(process_partition(sem, partition_id, partition_data))
for partition_id, partition_data in sorted(data.items())
]
# Track progress with tqdm as tasks complete
with tqdm(total=len(tasks), desc="Saving partitions") as progress_bar:
async def monitor_tasks():
for task in asyncio.as_completed(tasks):
try:
await asyncio.wait_for(task, timeout)
except asyncio.TimeoutError as e:
print(f"Timeout error: partition processing took longer than {timeout} seconds.")
raise e
except Exception as e:
print(f"Error processing partition in tqdm loop: {e}")
raise e
finally:
progress_bar.update(1)
# Run the monitoring coroutine
try:
loop.run_until_complete(monitor_tasks())
finally:
loop.close()
run_async_tasks()
self._invalidate_caches()
Wrapping up
We now have a hive-partitioned output dataset containing the result. We can use the SparkDataset to load it as a full table, i.e., promote the partition directories back into a column for downstream processing.
embeddings.feat.graph.node_embeddings@spark:
<<: *_spark_parquet
filepath: ${globals:paths.tmp}/feat/tmp_nodes_with_embeddings
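Under the hood this is equivalent to pointing a plain Spark reader at the partitioned directory, with the bucket directories reappearing as a regular column (a minimal sketch with an illustrative local path):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The hive-partitioned output directory is read back as a single table;
# the bucket=N directories become a "bucket" column again
nodes_with_embeddings = spark.read.parquet("/tmp/feat/tmp_nodes_with_embeddings")
nodes_with_embeddings.select("id", "bucket", "embedding").show(5)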
The final Kedro pipeline looks as follows:
from kedro.pipeline import Pipeline, node, pipeline

from . import nodes  # pipeline module containing bucketize_df and compute_embeddings


def create_pipeline(**kwargs) -> Pipeline:
return pipeline(
[
node(
func=nodes.bucketize_df,
inputs={
"df": "filtering.prm.filtered_nodes",
"input_features": "params:embeddings.node.input_features",
"bucket_size": "params:embeddings.node.batch_size",
"max_input_len": "params:embeddings.node.max_input_len",
},
outputs="embeddings.feat.bucketized_nodes@spark",
name="bucketize_nodes",
tags=["argowf.fuse", "argowf.fuse-group.node_embeddings"],
),
# Compute embeddings
node(
func=nodes.compute_embeddings,
inputs={
"dfs": "embeddings.feat.bucketized_nodes@partitioned",
"model": "params:embeddings.node.model",
},
outputs="embeddings.feat.graph.node_embeddings@partitioned",
name="add_node_embeddings",
tags=["argowf.fuse", "argowf.fuse-group.node_embeddings"],
)
# Next step inputs `embeddings.feat.graph.node_embeddings@spark` to ensure
# table is loaded as a full dataset.
]
)