Implementing a new expression in Daft touches several parts of the codebase: the compute kernel in the daft-core crate, the expression plumbing in the daft-dsl and daft-functions crates, and the Python type stubs. This guide walks through the process using the cbrt (cube root) function as an example. The same steps apply to your own custom expressions.

Step 1: Implement the Core Functionality

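First, implement the core functionality of your expression in the daft-core crate. This happens at two levels: a kernel on the typed DataArray<T>, and a method on Series that picks the right kernel for the input's data type.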
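At its core, a DataArray kernel such as cbrt is an element-wise map over a column. The minimal sketch below models that idea with a hypothetical FloatArray type that stores Vec<Option<f64>>. Daft's actual DataArray is backed by Arrow buffers with a separate validity bitmap, so treat this only as an illustration. It also shows why the kernel calls cbrt rather than powf(1.0 / 3.0): the latter returns NaN for negative inputs.

```rust
/// Hypothetical stand-in for a nullable float column (not Daft's DataArray).
#[derive(Debug, Clone, PartialEq)]
pub struct FloatArray {
    pub values: Vec<Option<f64>>,
}

impl FloatArray {
    /// Map `func` over every non-null value; `None` entries stay `None`.
    pub fn apply<F>(&self, func: F) -> FloatArray
    where
        F: Fn(f64) -> f64,
    {
        FloatArray {
            values: self.values.iter().map(|&v| v.map(&func)).collect(),
        }
    }

    /// Element-wise cube root, mirroring the shape of `DataArray::cbrt`.
    pub fn cbrt(&self) -> FloatArray {
        self.apply(f64::cbrt)
    }
}
```

For example, applying cbrt to [Some(27.0), None, Some(-8.0)] yields approximately [Some(3.0), None, Some(-2.0)].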

  1. Create a new module in src/daft-core/src/array/ops/ that adds the kernel to DataArray<T>. The T::Native: Float bound limits it to floating-point element types:

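    // src/daft-core/src/array/ops/cbrt.rs
    use num_traits::Float;
    use common_error::DaftResult;
    use crate::{array::DataArray, datatypes::DaftNumericType};
    
    impl<T> DataArray<T>
    where
        T: DaftNumericType,
        // Restrict to element types that have a cube root (f32, f64).
        T::Native: Float,
    {
        /// Element-wise cube root. Unlike powf(1.0 / 3.0), Float::cbrt gives
        /// real results for negative inputs (the cube root of -8.0 is -2.0).
        pub fn cbrt(&self) -> DaftResult<Self> {
            self.apply(|v| v.cbrt())
        }
    }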
    
  2. Add the new module to src/daft-core/src/array/ops/mod.rs:

    mod cbrt;
    
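  3. Expose the function on Series in src/daft-core/src/series/ops/. The Series method casts the input to its floating-point type and then dispatches to the matching DataArray kernel: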

    // src/daft-core/src/series/ops/cbrt.rs
    use common_error::DaftResult;
    use crate::datatypes::DataType;
    use crate::series::array_impl::IntoSeries;
    use crate::series::Series;
    
    impl Series {
        pub fn cbrt(&self) -> DaftResult<Series> {
            // Pick the floating-point dtype for this input; errors propagate via `?`.
            let casted_dtype = self.to_floating_data_type()?;
            let casted_self = self
                .cast(&casted_dtype)
                .expect("Casting numeric types to their floating point analogues should not fail");
            // After the cast only Float32 or Float64 remain; dispatch to the typed kernel.
            match casted_dtype {
                DataType::Float32 => Ok(casted_self.f32().unwrap().cbrt()?.into_series()),
                DataType::Float64 => Ok(casted_self.f64().unwrap().cbrt()?.into_series()),
                _ => unreachable!(),
            }
        }
    }
    
  4. Add the new module to src/daft-core/src/series/ops/mod.rs:

    pub mod cbrt;
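The Series::cbrt method above follows a promote, cast, then dispatch pattern. The std-only sketch below reproduces that pattern with hypothetical Dtype and Column types standing in for DataType and Series. Its promotion rule (Float32 stays Float32, the other numeric types shown widen to Float64, strings are rejected) is an assumption chosen for illustration; Daft's to_floating_data_type defines the real mapping.

```rust
/// Hypothetical subset of column dtypes (not Daft's DataType).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Dtype {
    Int32,
    Int64,
    Float32,
    Float64,
    Utf8,
}

/// Hypothetical type-erased column, playing the role of Series.
#[derive(Debug, Clone, PartialEq)]
pub enum Column {
    Int32(Vec<i32>),
    Int64(Vec<i64>),
    Float32(Vec<f32>),
    Float64(Vec<f64>),
    Utf8(Vec<String>),
}

impl Column {
    pub fn dtype(&self) -> Dtype {
        match self {
            Column::Int32(_) => Dtype::Int32,
            Column::Int64(_) => Dtype::Int64,
            Column::Float32(_) => Dtype::Float32,
            Column::Float64(_) => Dtype::Float64,
            Column::Utf8(_) => Dtype::Utf8,
        }
    }

    /// Assumed promotion rule: Float32 stays Float32, other numerics widen to Float64.
    pub fn to_floating_dtype(&self) -> Result<Dtype, String> {
        match self.dtype() {
            Dtype::Float32 => Ok(Dtype::Float32),
            Dtype::Int32 | Dtype::Int64 | Dtype::Float64 => Ok(Dtype::Float64),
            other => Err(format!("no floating-point representation for {:?}", other)),
        }
    }

    /// Cast to one of the targets returned by `to_floating_dtype`.
    fn cast_to_float(&self, target: Dtype) -> Column {
        match (self, target) {
            (Column::Float32(v), Dtype::Float32) => Column::Float32(v.clone()),
            (Column::Float64(v), Dtype::Float64) => Column::Float64(v.clone()),
            (Column::Int32(v), Dtype::Float64) => {
                Column::Float64(v.iter().map(|&x| f64::from(x)).collect())
            }
            (Column::Int64(v), Dtype::Float64) => {
                // i64 -> f64 can lose precision for very large magnitudes.
                Column::Float64(v.iter().map(|&x| x as f64).collect())
            }
            _ => unreachable!("to_floating_dtype only returns supported targets"),
        }
    }

    /// Promote, cast, then dispatch on the float variant, like Series::cbrt.
    pub fn cbrt(&self) -> Result<Column, String> {
        let target = self.to_floating_dtype()?;
        match self.cast_to_float(target) {
            Column::Float32(v) => Ok(Column::Float32(v.iter().map(|x| x.cbrt()).collect())),
            Column::Float64(v) => Ok(Column::Float64(v.iter().map(|x| x.cbrt()).collect())),
            _ => unreachable!("cast_to_float only produces float columns"),
        }
    }
}
```

Doing the cast once up front keeps the kernel itself float-only, which is why the DataArray implementation can require T::Native: Float.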
    

Step 2: Implement the DSL Function

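Next, implement the DSL (domain-specific language) function in the daft-dsl crate. The evaluator has two jobs: to_field derives the output field (name and data type) from the input schema, and evaluate runs the function on already-materialized Series.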
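to_field typically runs while the query plan is being built, before any data is read: it resolves the input expression against the schema, checks the arity, and derives the output field. The sketch below mirrors the shape of CbrtEvaluator::to_field using hypothetical Schema, Expr, Field, and Dtype types (only column references are modeled) and an assumed floating_representation mapping.

```rust
use std::collections::HashMap;

/// Hypothetical dtype enum (not Daft's DataType).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Dtype {
    Int64,
    Float32,
    Float64,
    Utf8,
}

/// Hypothetical field: a column name plus its dtype.
#[derive(Debug, Clone, PartialEq)]
pub struct Field {
    pub name: String,
    pub dtype: Dtype,
}

impl Dtype {
    /// Assumed mapping to a floating-point output dtype.
    pub fn floating_representation(self) -> Result<Dtype, String> {
        match self {
            Dtype::Float32 => Ok(Dtype::Float32),
            Dtype::Int64 | Dtype::Float64 => Ok(Dtype::Float64),
            Dtype::Utf8 => Err(format!("no floating-point representation for {:?}", self)),
        }
    }
}

/// Hypothetical schema: column name -> dtype.
pub type Schema = HashMap<String, Dtype>;

/// Hypothetical expression type; only column references are modeled.
pub enum Expr {
    Col(String),
}

impl Expr {
    /// Resolve this expression's output field against a schema.
    pub fn to_field(&self, schema: &Schema) -> Result<Field, String> {
        match self {
            Expr::Col(name) => schema
                .get(name)
                .map(|&dtype| Field { name: name.clone(), dtype })
                .ok_or_else(|| format!("column {} not found in schema", name)),
        }
    }
}

/// Same shape as CbrtEvaluator::to_field: one input, name kept, dtype made floating.
pub fn cbrt_to_field(inputs: &[Expr], schema: &Schema) -> Result<Field, String> {
    match inputs {
        [first] => {
            let field = first.to_field(schema)?;
            let dtype = field.dtype.floating_representation()?;
            Ok(Field { name: field.name, dtype })
        }
        _ => Err(format!("Expected 1 input arg, got {}", inputs.len())),
    }
}
```

Because the type error surfaces here, a call like cbrt on a string column can be rejected before any data is processed.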

  1. Create a new file for your function in src/daft-dsl/src/functions/numeric/:

    // src/daft-dsl/src/functions/numeric/cbrt.rs
    use common_error::{DaftError, DaftResult};
    use daft_core::{datatypes::Field, schema::Schema, series::Series};
    use super::super::FunctionEvaluator;
    use crate::functions::FunctionExpr;
    use crate::ExprRef;
    
    pub struct CbrtEvaluator {}
    
    impl FunctionEvaluator for CbrtEvaluator {
        fn name(&self) -> &'static str {
            "cbrt"
        }
    
        fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
            match inputs {
                [first] => {
                    let field = first.to_field(schema)?;
                    let dtype = field.dtype.to_floating_representation()?;
                    Ok(Field::new(field.name, dtype))
                }
                _ => Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input arg, got {}",
                    inputs.len()
                ))),
            }
        }
    
        fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
            match inputs {
                [first] => first.cbrt(),
                _ => Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input arg, got {}",
                    inputs.len()
                ))),
            }
        }
    }
    
  2. Register the module, add a Cbrt variant, and map it to CbrtEvaluator in src/daft-dsl/src/functions/numeric/mod.rs:

    mod cbrt;

    use cbrt::CbrtEvaluator;

    pub enum NumericExpr {
        // ... other expressions ...
        Cbrt,
    }
    
    impl NumericExpr {
        #[inline]
        pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
            match self {
                // ... other expressions ...
                NumericExpr::Cbrt => &CbrtEvaluator {},
            }
        }
    }
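get_evaluator maps each NumericExpr variant to a zero-sized evaluator and hands it back as &dyn FunctionEvaluator. Returning &CbrtEvaluator {} from a method works because a reference to a constant struct literal with no fields and no Drop impl is promoted to a 'static value. Here is a self-contained sketch of the same enum-to-trait-object pattern, using a simplified hypothetical trait that works on Vec<f64> instead of Series and a second hypothetical Abs variant for contrast:

```rust
/// Simplified, hypothetical evaluator trait (Vec<f64> instead of Series).
pub trait FunctionEvaluator {
    fn name(&self) -> &'static str;
    fn evaluate(&self, inputs: &[Vec<f64>]) -> Result<Vec<f64>, String>;
}

/// Shared arity check, mirroring the `[first] => ...` match in CbrtEvaluator.
fn unary(inputs: &[Vec<f64>], f: fn(f64) -> f64) -> Result<Vec<f64>, String> {
    match inputs {
        [first] => Ok(first.iter().map(|&x| f(x)).collect()),
        _ => Err(format!("Expected 1 input arg, got {}", inputs.len())),
    }
}

pub struct CbrtEvaluator {}
pub struct AbsEvaluator {}

impl FunctionEvaluator for CbrtEvaluator {
    fn name(&self) -> &'static str {
        "cbrt"
    }
    fn evaluate(&self, inputs: &[Vec<f64>]) -> Result<Vec<f64>, String> {
        unary(inputs, f64::cbrt)
    }
}

impl FunctionEvaluator for AbsEvaluator {
    fn name(&self) -> &'static str {
        "abs"
    }
    fn evaluate(&self, inputs: &[Vec<f64>]) -> Result<Vec<f64>, String> {
        unary(inputs, f64::abs)
    }
}

pub enum NumericExpr {
    Abs,
    Cbrt,
}

impl NumericExpr {
    /// Each arm returns a reference to a promoted 'static zero-sized evaluator.
    pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
        match self {
            NumericExpr::Abs => &AbsEvaluator {},
            NumericExpr::Cbrt => &CbrtEvaluator {},
        }
    }
}
```

Forgetting the match arm for a new variant is a compile error, since the match must be exhaustive.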
    

Step 3: Implement the Python Bindings

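Next, implement the function and its Python binding in the daft-functions crate. Here cbrt is defined as a ScalarUDF, and the #[pyfunction] wrapper builds a ScalarFunction expression that applies it to the input expression.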
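In this design the function is a value stored inside a generic function node, and the Python-facing wrapper only builds that node; nothing is computed until the expression is evaluated. The std-only sketch below models that with a hypothetical ScalarUdf trait, an Arc<dyn ScalarUdf> inside an Expr enum, and a plain cbrt constructor standing in for the #[pyfunction] wrapper. It leaves out serialization (the #[typetag::serde] part of the real code) and any schema-inference methods.

```rust
use std::fmt::Debug;
use std::sync::Arc;

/// Hypothetical, trimmed-down scalar UDF trait (no serde, no schema inference).
pub trait ScalarUdf: Debug + Send + Sync {
    fn name(&self) -> &'static str;
    fn evaluate(&self, inputs: &[Vec<f64>]) -> Result<Vec<f64>, String>;
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CbrtFunction;

impl ScalarUdf for CbrtFunction {
    fn name(&self) -> &'static str {
        "cbrt"
    }
    fn evaluate(&self, inputs: &[Vec<f64>]) -> Result<Vec<f64>, String> {
        match inputs {
            [first] => Ok(first.iter().map(|x| x.cbrt()).collect()),
            _ => Err(format!("Expected 1 input arg, got {}", inputs.len())),
        }
    }
}

/// Hypothetical expression tree: literals and generic function nodes.
#[derive(Debug, Clone)]
pub enum Expr {
    Literal(Vec<f64>),
    Function {
        udf: Arc<dyn ScalarUdf>,
        inputs: Vec<Expr>,
    },
}

impl Expr {
    /// Build a function node, loosely analogous to ScalarFunction::new(udf, inputs).
    pub fn function(udf: impl ScalarUdf + 'static, inputs: Vec<Expr>) -> Expr {
        Expr::Function {
            udf: Arc::new(udf),
            inputs,
        }
    }

    /// Evaluate child expressions first, then apply the UDF to their results.
    pub fn eval(&self) -> Result<Vec<f64>, String> {
        match self {
            Expr::Literal(values) => Ok(values.clone()),
            Expr::Function { udf, inputs } => {
                let args = inputs.iter().map(Expr::eval).collect::<Result<Vec<_>, _>>()?;
                udf.evaluate(&args)
            }
        }
    }
}

/// Plays the role of the #[pyfunction] wrapper: it only builds the node.
pub fn cbrt(expr: Expr) -> Expr {
    Expr::function(CbrtFunction, vec![expr])
}
```

Since the node holds a trait object, new functions can be added without touching the Expr enum itself.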

  1. Create a new file for your function in src/daft-functions/src/numeric/:

    // src/daft-functions/src/numeric/cbrt.rs
    use common_error::{DaftError, DaftResult};
    use daft_core::{datatypes::Field, schema::Schema, Series};
    use daft_dsl::{functions::ScalarUDF, ExprRef};
    use serde::{Deserialize, Serialize};
    
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
    struct CbrtFunction;
    
    #[typetag::serde]
    impl ScalarUDF for CbrtFunction {
        // ... implement the required methods ...
    }
    
    #[cfg(feature = "python")]
    pub mod python {
        use daft_dsl::{functions::ScalarFunction, python::PyExpr, ExprRef};
        use pyo3::{pyfunction, PyResult};
    
        use super::CbrtFunction;
    
        #[pyfunction]
        pub fn cbrt(expr: PyExpr) -> PyResult<PyExpr> {
            let scalar_function = ScalarFunction::new(CbrtFunction, vec![expr.into()]);
            let expr = ExprRef::from(scalar_function);
            Ok(expr.into())
        }
    }
    
  2. Add the new module to src/daft-functions/src/numeric/mod.rs:

    pub mod cbrt;
    
  3. Register the Python function in src/daft-functions/src/lib.rs:

    pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
        // ... other registrations ...
        parent.add_wrapped(wrap_pyfunction!(numeric::cbrt::python::cbrt))?;
        Ok(())
    }
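Registration is what makes cbrt reachable from Python: a native module behaves, in effect, like a table from attribute names to callables, and a function that is never added is simply not there. The sketch below is an analogy only, with a hypothetical Module type backed by a HashMap rather than PyO3; add_function loosely plays the role of add_wrapped.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a native module's attribute table (not PyO3).
#[derive(Default)]
pub struct Module {
    functions: HashMap<&'static str, fn(f64) -> f64>,
}

impl Module {
    /// Expose `f` under `name`; loosely analogous to `parent.add_wrapped(...)`.
    pub fn add_function(&mut self, name: &'static str, f: fn(f64) -> f64) {
        self.functions.insert(name, f);
    }

    /// Look up a function by name, failing like a missing module attribute would.
    pub fn call(&self, name: &str, x: f64) -> Result<f64, String> {
        self.functions
            .get(name)
            .map(|f| f(x))
            .ok_or_else(|| format!("module has no attribute '{}'", name))
    }
}

/// Registration step: without this call, looking up "cbrt" fails.
pub fn register_modules(parent: &mut Module) {
    parent.add_function("cbrt", f64::cbrt);
}
```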
    

Step 4: Update Python Type Hints

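Update the Python type stubs in daft/daft.pyi so they match what the Rust side exposes. Note that the PySeries.cbrt stub assumes a matching cbrt method on the Rust PySeries binding, which the steps above do not show:

    def cbrt(expr: PyExpr) -> PyExpr: ...

    class PySeries:
        # ... other methods ...
        def cbrt(self) -> PySeries: ...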

Step 5: Add Tests