为了账号安全,请及时绑定邮箱和手机立即绑定

给 Postgres 写一个向量插件 - 向量类型

标签:
数据库

在这篇文章中,我们将为 Postgres 实现 vector 类型:

CREATE TABLE items (v vector(3));

Postgres 扩展结构和 pgrx 包装器

在实现它之前,让我们先看看典型的扩展结构,以及 pgrx 如何为我们简化它。

典型的 Postgres 扩展可以大致分为 2 层:

  1. 实现,通常使用 C 等低级语言完成。
  2. 将实现粘合到 Postgres 的高级 SQL 语句。
  3. 指定扩展的一些基本属性的控制文件。

如果你看一下 pgvector 的源代码,这个 3 层结构非常明显,src 目录
用于 C 代码,sql 目录包含更高级的 SQL 胶水,还有一个
.control 文件。那么 pgrx 如何使扩展构建更容易?

  1. 它用 Rust 包装 Postgres C API

正如我们所说,即使我们用 Rust 构建扩展,Postgres 的 API 仍然是 C,pgrx 尝试将它们包装在 Rust 中,这样我们就不需要为 C 烦恼了。

  1. 如果可能,使用 Rust 宏生成 SQL 胶水

稍后我们将看到 pgrx 可以自动为我们生成 SQL 胶水。

  1. pgrx 为我们生成 .control 文件

CREATE TYPE vector

我们来定义我们的 Vector 类型,使用 std::vec::Vec 看起来非常简单,而且由于 vector 需要存储浮点数,我们在这里使用 f64:

struct Vector {
    value: Vec<f64>
}

然后呢?

用于创建新类型的 SQL 语句是 CREATE TYPE ...,从它的 文档,我们会知道我们正在实现的 vector 类型是一个 基类型,要创建基类型,需要支持函数 input_functionoutput_function。而且由于它需要采用使用 modifer 实现的维度参数(vector(DIMENSION)),因此还需要函数 type_modifier_input_functiontype_modifier_output_function。因此,我们需要为我们的 Vector 类型实现这 4 个函数。

input_function

引用文档,

input_function 将类型的外部文本表示转换为为该类型定义的运算符和函数使用的内部表示。

输入函数可以声明为采用一个 cstring 类型的参数,也可以声明为采用三个 cstringoidinteger 类型的参数。第一个参数是作为 C 字符串的输入文本,第二个参数是类型自己的 OID(数组类型除外,它们接收其元素类型的 OID),第三个是目标列的 typmod(如果已知)(如果未知,则传递 -1)。输入函数必须返回数据类型本身的值。

好的,从文档来看,这个 input_function 用于反序列化,serde 是 Rust 中最流行的反序列化库,所以让我们使用它。对于参数,由于 vector 需要类型修饰符,我们需要它接受 3 个参数。我们的 input_function 如下所示:

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_input(
    input: &CStr,
    _oid: pg_sys::Oid,
    type_modifier: i32,
) -> Vector {
    let value = match serde_json::from_str::<Vec<f64>>(
        input.to_str().expect("expect input to be UTF-8 encoded"),
    ) {
        Ok(v) => v,
        Err(e) => {
            pgrx::error!("failed to deserialize the input string due to error {}", e)
        }
    };
    let dimension = match u16::try_from(value.len()) {
        Ok(d) => d,
        Err(_) => {
            pgrx::error!("this vector's dimension [{}] is too large", value.len());
        }
    };

    // cast should be safe as dimension should be a positive
    let expected_dimension = match u16::try_from(type_modifier) {
        Ok(d) => d,
        Err(_) => {
            panic!("failed to cast stored dimension [{}] to u16", type_modifier);
        }
    };

    // check the dimension
    if dimension != expected_dimension {
        pgrx::error!(
            "mismatched dimension, expected {}, found {}",
            expected_dimension,
            dimension
        );
    }

    Vector { value }
}

这有一大堆东西,让我们逐一研究一下。

#[pg_extern(immutable, strict, parallel_safe, require = [ "shell_type" ])]

如果你用 pg_extern 标记一个函数,那么 pgrx 会自动为你生成类似 CREATE FUNCTION <你的函数> 的 SQL,immutable, strict, parallel_safe 是你认为你的函数具有的属性,它们与 CREATE FUNCTION 文档 中列出的属性相对应。因为这个 Rust 宏用于生成 SQL,并且 SQL 可以相互依赖,所以这个 requires = [ "shell_type" ] 用于明确这种依赖关系。

shell_type 是另一个定义 shell 类型的 SQL 代码段的名称,什么是 shell 类型?它的行为就像一个占位符,这样我们在完全实现它之前就可以有一个 vector 类型来使用。此 #[pg_extern] 宏生成的 SQL 将是:

CREATE FUNCTION "vector_input"(
    "input" cstring,
    "_oid" oid,
    "type_modifier" INT
) RETURNS vector

如您所见,此函数 RETURNS vector,但在实现这 4 个必需函数之前,我们如何才能拥有 vector 类型?

circular_dependency

Shell 类型正是为此而生!我们可以定义一个 shell 类型(虚拟类型,不需要提供任何函数),并让我们的函数依赖于它:

vector_type

pgrx 不会为我们定义这个 shell 类型,我们需要在 SQL 中手动执行此操作,如下所示:

extension_sql!(
    r#"CREATE TYPE vector; -- shell type"#,
    name = "shell_type"
);

extension_sql!() 宏允许我们在 Rust 代码中编写 SQL,然后 pgrx 会将其包含在生成的 SQL 脚本中。name = "shell_type" 指定此 SQL 代码段的标识符,可用于引用它。我们的 vector_input() 函数依赖于此 shell 类型,因此它 requires = [ "shell_type" ]

fn vector_input(
    input: &CStr,
    _oid: pg_sys::Oid,
    type_modifier: i32,
) -> Vector {

input 参数是一个表示我们的向量输入文本的字符串,_oid_ 为前缀,因为我们不需要它。type_modifier 参数的类型为 i32,这就是类型修饰符在 Postgres 中的存储方式。当我们实现类型修饰符输入/输出函数时,我们将再次看到它。

let value = match serde_json::from_str::<Vec<f64>>(
    input.to_str().expect("expect input to be UTF-8coded"),
) {
    Ok(v) => v,
    Err(e) => {
        pgrx::error!("failed to deserialize the input string due to error {}", e)
    }
};

然后我们将 input 转换为 UTF-8 编码的 &str 并将其传递给 serde_json::from_str()。输入文本应该是 UTF-8 编码的,所以我们应该是安全的。如果在反序列化过程中发生任何错误,只需使用 pgrx::error!() 输出错误,它将在 error 级别记录并终止当前事务。

let dimension = match u16::try_from(value.len()) {
    Ok(d) => d,
    Err(_) => {
        pgrx::error!("此向量的维度 [{}] 太大", value.len());
    }
};

// cast should be safe as dimension should be a positive
let expected_dimension = match u16::try_from(type_modifier) {
    Ok(d) => d,
    Err(_) => {
        panic!("无法将存储的维度 [{}] 转换为 u16", type_modifier);
    }
};

我们支持的最大维度是 u16::MAX,我们这样做只是因为这是 pgvector 所做的。

// check the dimension
if dimension != expected_dimension {
    pgrx::error!(
        "mismatched dimension, expected {}, found {}",
        expected_dimension,
        dimension
    );
}

Vector { value }

最后,我们检查输入向量是否具有预期的维度,如果没有,则出错。否则,我们返回解析后的向量。

output_function

output_function 执行反向操作,它将给定的向量序列化为字符串。这是我们的实现:

#[pg_extern(immutable, strict, parallel_safe, require = [ "shell_type" ])]
fn vector_output(value: Vector) -> CString {
    let value_serialized_string = serde_json::to_string(&value).unwrap();
    CString::new(value_serialized_string).expect("中间不应该有 NUL")
}

我们只需序列化 Vec<f64> 并将其返回到 CString 中,非常简单。

type_modifier_input_function

type_modifier_input_function 应该解析输入修饰符,检查解析的修饰符,如果有效,则将其编码为整数,这是 Postgres 存储类型修饰符的方式。

一个类型可以接受多个类型修饰符,用 , 分隔,这就是我们在这里看到数组的原因。

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_modifier_input(list: pgrx::datum::Array<&CStr>) -> i32 {
    if list.len() != 1 {
        pgrx::error!("too many modifiers, expect 1")
    }

    let modifier = list
        .get(0)
        .expect("should be Some as len = 1")
        .expect("type modifier cannot be NULL");
    let Ok(dimension) = modifier
        .to_str()
        .expect("expect type modifiers to be UTF-8 encoded")
        .parse::<u16>()
    else {
        pgrx::error!("invalid dimension value, expect a number in range [1, 65535]")
    };

    dimension as i32
}

实现很简单,vector 类型只接受 1 个修饰符,因此我们验证数组长度是否为 1。如果是,则尝试将其解析为 u16,如果没有发生错误,则将其作为 i32 返回,以便 Postgres 可以存储它。

type_modifier_output_function

#[pg_extern(immutable, strict, parallel_safe, require = [ "shell_type" ])]
fn vector_modifier_output(type_modifier: i32) -> CString {
    CString::new(format!("({})", type_modifier)).expect("No NUL in the middle")
}

type_modifier_output_function 的实现也很简单,只需返回一个格式为 (type_modifier) 的字符串即可。

所有必需的函数都已实现,现在我们终于可以 CREATE 它了!

// create the actual type, specifying the input and output functions
extension_sql!(
    r#"
CREATE TYPE vector (
    INPUT = vector_input,
    OUTPUT = vector_output,
    TYPMOD_IN = vector_modifier_input,
    TYPMOD_OUT = vector_modifier_output,
    STORAGE = external
);
"#,
    name = "concrete_type",
    creates = [Type(Vector)],
    requires = [
        "shell_type",
        vector_input,
        vector_output,
        vector_modifier_input,
        vector_modifier_output
    ]
);

大家可能对STORAGE = external这个参数比较陌生,我这里简单介绍一下。 Postgres 默认使用的存储引擎是堆,使用这个引擎,表的磁盘文件被分割成页,页的大小默认为 8192 字节,Postgres 希望至少 4 行可以存储在一个页中,如果某些列太大,无法容纳 4 行,Postgres 会将它们移动到外部表,即 TOAST 表。这个 STORAGE 参数控制 Postgres 在将列移动到 TOAST 表中时的行为。将其设置为 external 意味着如果该列太大,则可以将其移动到 TOAST 表中。

我们的第一次启动!

看起来一切就绪,让我们进行第一次启动吧!

$ cargo pgrx run
error[E0277]: the trait bound `for<'a> fn(&'a CStr, Oid, i32) -> Vector: FunctionMetadata<_>` is not satisfied
  --> pg_extension/src/vector_type.rs:87:1
   |
86 | #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
   | --------------------------------------------------------------------------- in this procedural macro expansion
87 | fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector {
   | ^^ the trait `FunctionMetadata<_>` is not implemented for `for<'a> fn(&'a CStr, Oid, i32) -> Vector`
   |
   = help: the following other types implement trait `FunctionMetadata<A>`:
             `unsafe fn() -> R` implements `FunctionMetadata<()>`
             `unsafe fn(T0) -> R` implements `FunctionMetadata<(T0,)>`
             `unsafe fn(T0, T1) -> R` implements `FunctionMetadata<(T0, T1)>`
             `unsafe fn(T0, T1, T2) -> R` implements `FunctionMetadata<(T0, T1, T2)>`
             `unsafe fn(T0, T1, T2, T3) -> R` implements `FunctionMetadata<(T0, T1, T2, T3)>`
             `unsafe fn(T0, T1, T2, T3, T4) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4)>`
             `unsafe fn(T0, T1, T2, T3, T4, T5) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4, T5)>`
             `unsafe fn(T0, T1, T2, T3, T4, T5, T6) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4, T5, T6)>`
           and 25 others
   = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0599]: the function or associated item `entity` exists for struct `Vector`, but its trait bounds were not satisfied
  --> pg_extension/src/vector_type.rs:86:1
   |
22 | pub(crate) struct Vector {
   | ------------------------ function or associated item `entity` not found for this struct because it doesn't satisfy `Vector: SqlTranslatable`
...
86 | #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `Vector` due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `Vector: SqlTranslatable`
           which is required by `&Vector: SqlTranslatable`
note: the trait `SqlTranslatable` must be implemented
  --> $HOME/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pgrx-sql-entity-graph-0.12.9/src/metadata/sql_translatable.rs:73:1
   |
73 | pub unsafe trait SqlTranslatable {
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   = help: items from traits can only be used if the trait is implemented and in scope
   = note: the following traits define an item `entity`, perhaps you need to implement one of them:
           candidate #1: `FunctionMetadata`
           candidate #2: `SqlTranslatable`
   = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `Vector: RetAbi` is not satisfied
  --> pg_extension/src/vector_type.rs:87:73
   |
87 | fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector {
   |                                                                         ^^^^^^ the trait `BoxRet` is not implemented for `Vector`, which is required by `Vector: RetAbi`
   |
   = help: the following other types implement trait `BoxRet`:
             &'a CStr
             &'a [u8]
             &'a str
             ()
             AnyArray
             AnyElement
             AnyNumeric
             BOX
           and 33 others
   = note: required for `Vector` to implement `RetAbi`

error[E0277]: the trait bound `Vector: RetAbi` is not satisfied
   --> pg_extension/src/vector_type.rs:87:80
    |
86  |   #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
    |   --------------------------------------------------------------------------- in this procedural macro expansion
87  |   fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector {
    |  ________________________________________________________________________________^
88  | |     let value = match serde_json::from_str::<Vec<f64>>(
89  | |         input.to_str().expect("expect input to be UTF-8 encoded"),
90  | |     ) {
...   |
110 | |     Vector { value }
111 | | }
    | |_^ the trait `BoxRet` is not implemented for `Vector`, which is required by `Vector: RetAbi`
    |
    = help: the following other types implement trait `BoxRet`:
              &'a CStr
              &'a [u8]
              &'a str
              ()
              AnyArray
              AnyElement
              AnyNumeric
              BOX
            and 33 others
    = note: required for `Vector` to implement `RetAbi`
    = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `fn(Vector) -> CString: FunctionMetadata<_>` is not satisfied
   --> pg_extension/src/vector_type.rs:114:1
    |
113 | #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
    | --------------------------------------------------------------------------- in this procedural macro expansion
114 | fn vector_output(value: Vector) -> CString {
    | ^^ the trait `FunctionMetadata<_>` is not implemented for `fn(Vector) -> CString`
    |
    = help: the following other types implement trait `FunctionMetadata<A>`:
              `unsafe fn() -> R` implements `FunctionMetadata<()>`
              `unsafe fn(T0) -> R` implements `FunctionMetadata<(T0,)>`
              `unsafe fn(T0, T1) -> R` implements `FunctionMetadata<(T0, T1)>`
              `unsafe fn(T0, T1, T2) -> R` implements `FunctionMetadata<(T0, T1, T2)>`
              `unsafe fn(T0, T1, T2, T3) -> R` implements `FunctionMetadata<(T0, T1, T2, T3)>`
              `unsafe fn(T0, T1, T2, T3, T4) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4)>`
              `unsafe fn(T0, T1, T2, T3, T4, T5) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4, T5)>`
              `unsafe fn(T0, T1, T2, T3, T4, T5, T6) -> R` implements `FunctionMetadata<(T0, T1, T2, T3, T4, T5, T6)>`
            and 25 others
    = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0599]: the function or associated item `entity` exists for struct `Vector`, but its trait bounds were not satisfied
   --> pg_extension/src/vector_type.rs:113:1
    |
22  | pub(crate) struct Vector {
    | ------------------------ function or associated item `entity` not found for this struct because it doesn't satisfy `Vector: SqlTranslatable`
...
113 | #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `Vector` due to unsatisfied trait bounds
    |
    = note: the following trait bounds were not satisfied:
            `Vector: SqlTranslatable`
            which is required by `&Vector: SqlTranslatable`
note: the trait `SqlTranslatable` must be implemented
   --> $HOME/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pgrx-sql-entity-graph-0.12.9/src/metadata/sql_translatable.rs:73:1
    |
73  | pub unsafe trait SqlTranslatable {
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    = help: items from traits can only be used if the trait is implemented and in scope
    = note: the following traits define an item `entity`, perhaps you need to implement one of them:
            candidate #1: `FunctionMetadata`
            candidate #2: `SqlTranslatable`
    = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `Vector: ArgAbi<'_>` is not satisfied
   --> pg_extension/src/vector_type.rs:114:18
    |
113 | #[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
    | --------------------------------------------------------------------------- in this procedural macro expansion
114 | fn vector_output(value: Vector) -> CString {
    |                  ^^^^^ the trait `ArgAbi<'_>` is not implemented for `Vector`
    |
    = help: the following other types implement trait `ArgAbi<'fcx>`:
              &'fcx T
              &'fcx [u8]
              &'fcx str
              *mut FunctionCallInfoBaseData
              AnyArray
              AnyElement
              AnyNumeric
              BOX
            and 36 others
note: required by a bound in `pgrx::callconv::Args::<'a, 'fcx>::next_arg_unchecked`
   --> $HOME/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pgrx-0.12.9/src/callconv.rs:931:41
    |
931 |     pub unsafe fn next_arg_unchecked<T: ArgAbi<'fcx>>(&mut self) -> Option<T> {
    |                                         ^^^^^^^^^^^^ required by this bound in `Args::<'a, 'fcx>::next_arg_unchecked`
    = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

修复错误!

特征绑定 for<'a> fn(&'a CStr, Oid, i32) -> Vector: FunctionMetadata<_> 不满足

好吧,这个错误的消息相当混乱,但确实有一个有用的提示:

note: the trait `SqlTranslatable` must be implemented
  --> $HOME/.cargo/registry/src/index.crates.io-6f17d22bba15001f/pgrx-sql-entity-graph-0.12.9/src/metadata/sql_translatable.rs:73:1
   |
73 | pub unsafe trait SqlTranslatable {
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   = help: items from traits can only be used if the trait is implemented and in scope
   = note: the following traits define an item `entity`, perhaps you need to implement one of them:
           candidate #1: `FunctionMetadata`
           candidate #2: `SqlTranslatable`
   = note: this error originates in the attribute macro `pg_extern` (in Nightly builds, run with -Z macro-backtrace for more info)

此注释说我们应该为我们的Vector类型实现SqlTranslatable,特征是做什么用的?其文档说此特征代表“可以在 SQL 中表示的值”,这有点令人困惑。让我解释一下。由于 pgrx 会为我们生成 SQL,并且会使用我们的“Vector”类型,因此需要我们告诉它我们的类型在 SQL 中应该被调用什么,也就是将其转换为 SQL。

在 SQL 中,我们的 Vector 类型简称为 vector,因此我们的实现如下:

unsafe impl SqlTranslatable for Vector {
    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
        Ok(SqlMapping::As("vector".into()))
    }

    fn return_sql() -> Result<Returns, ReturnsError> {
        Ok(Returns::One(SqlMapping::As("vector".into())))
    }
}

特征绑定 Vector: RetAbiVector: ArgAbi 不满足

我们的 Vector 类型有其内存表示(如何在内存中布局),而 Postgres 有自己的方式(或规则)将事物存储在内存中。 vector_input() 函数向 Postgres 返回一个 Vector,而 vector_output() 函数接受一个 Vector 参数,该参数将由 Postgres 提供:

fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector
fn vector_output(value: Vector) -> CString

使用 Vector 作为函数返回值和参数要求它能够按照 Postgres 的规则表示。Postgres 使用 Datum,这是所有 SQL 类型的二进制表示。因此,我们需要通过这两个特征在我们的 Vector 类型和 Datum 之间提供往返。

由于我们的 Vector 类型基本上是 std::vec::Vec<f64>,并且这两个特征是为 Vec 实现的,因此我们可以简单地使用这些实现:

impl FromDatum for Vector {
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        let value = <Vec<f64> as FromDatum>::from_polymorphic_datum(datum, is_null, typoid)?;
        Some(Self { value })
    }
}

impl IntoDatum for Vector {
    fn into_datum(self) -> Option<pg_sys::Datum> {
        self.value.into_datum()
    }

    fn type_oid() -> pg_sys::Oid {
        rust_regtypein::<Self>()
    }
}

unsafe impl<'fcx> ArgAbi<'fcx> for Vector
where
    Self: 'fcx,
{
    unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, 'fcx>) -> Self {
        unsafe { arg.unbox_arg_using_from_datum().expect("expect it to be non-NULL") }
    }
}

unsafe impl BoxRet for Vector {
    unsafe fn box_into<'fcx>(
        self,
        fcinfo: &mut pgrx::callconv::FcInfo<'fcx>,
    ) -> pgrx::datum::Datum<'fcx> {
        match self.into_datum() {
            Some(datum) => unsafe {fcinfo.return_raw_datum(datum)}
            None => fcinfo.return_null()
        }
    }
}

FromDatumIntoDatum 也已实现,因为它们将在 RetAbiArgAbi 实现中使用。

让我们看看是否有任何错误:

$ cargo check
$ echo $?
0

太好了,没有错误,让我们再次启动它!

我们第二次启动!

$ cargo pgrx run
    ...
psql (17.2)
Type "help" for help.

pg_vector_ext=#

是的,“psql”启动没有任何问题!让我们启用扩展并创建一个带有“vector”列的表:

pg_vector_ext=# CREATE EXTENSION pg_vector_ext;
CREATE EXTENSION
pg_vector_ext=# CREATE TABLE items (v vector(3));
CREATE TABLE

到目前为止一切顺利,现在让我们插入一些数据:

pg_vector_ext=# INSERT INTO items values ('[1, 2, 3]');
ERROR:  failed to cast stored dimension [-1] to u16
LINE 1: INSERT INTO items values ('[1, 2, 3]');
                                  ^
DETAIL:
   0: std::backtrace_rs::backtrace::libunwind::trace
             at /rustc/f6e511eec7342f59a25f7c0534f1dbea00d01b14/library/std/src/../../backtrace/src/backtrace/libunwind.rs:116:5
             ...
  46: <unknown>
             at main.c:197:3

好吧,Postgres 慌了,因为 type_modifier 参数是 -1,无法转换为 u16。但是 type_modifier 怎么会是 -1,当且仅当类型修饰符未知时,它不是 -1 吗?类型修饰符 3 显然存储在元数据中,因此肯定是已知的。

我在这里卡了一段时间,猜猜怎么着,我不是唯一一个。恕我直言,这是一个错误,但并不是每个人都这么认为。让我们接受它只能是 -1 的事实,解决这个问题的方法是创建一个转换函数,将 Vector 转换为 Vector,您可以在其中访问存储的类型修饰符:

/// Cast a `vector` to a `vector`, the conversion is meaningless, but we do need
/// to do the dimension check here if we cannot get the `typmod` value in vector
/// type's input function.
#[pgrx::pg_extern(immutable, strict, parallel_safe, requires = ["concrete_type"])]
fn cast_vector_to_vector(vector: Vector, type_modifier: i32, _explicit: bool) -> Vector {
    let expected_dimension = u16::try_from(type_modifier).expect("invalid type_modifier") as usize;
    let dimension = vector.value.len();
    if vector.value.len() != expected_dimension {
        pgrx::error!(
            "mismatched dimension, expected {}, found {}",
            type_modifier,
            dimension
        );
    }

    vector
}

extension_sql!(
    r#"
    CREATE CAST (vector AS vector)
    WITH FUNCTION cast_vector_to_vector(vector, integer, boolean);
    "#,
    name = "cast_vector_to_vector",
    requires = ["concrete_type", cast_vector_to_vector]
);

在我们的 vector_input() 函数中,如果类型修饰符未知,则不会执行维度检查:

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector {
    let value = match serde_json::from_str::<Vec<f64>>(
        input.to_str().expect("expect input to be UTF-8 encoded"),
    ) {
        Ok(v) => v,
        Err(e) => {
            pgrx::error!("failed to deserialize the input string due to error {}", e)
        }
    };
    let dimension = match u16::try_from(value.len()) {
        Ok(d) => d,
        Err(_) => {
            pgrx::error!("this vector's dimension [{}] is too large", value.len());
        }
    };

    // check the dimension in INPUT function if we know the expected dimension.
    if type_modifier != -1 {
        let expected_dimension = match u16::try_from(type_modifier) {
            Ok(d) => d,
            Err(_) => {
                panic!("failed to cast stored dimension [{}] to u16", type_modifier);
            }
        };

        if dimension != expected_dimension {
            pgrx::error!(
                "mismatched dimension, expected {}, found {}",
                expected_dimension,
                dimension
            );
        }
    }
    // If we don't know the type modifier, do not check the dimension.

    Vector { value }
}

现在一切都应该正常工作了:

pg_vector_ext=# INSERT INTO items values ('[1, 2, 3]');
INSERT 0 1

INSERT 工作了, 让我们试试 SELECT:

pg_vector_ext=# SELECT * FROM items;
       v
---------------
 [1.0,2.0,3.0]
(1 row)

恭喜!您刚刚为 Postgres 添加了 vector 类型支持!以下是我们实现的完整代码:

src/vector_type.rs

//! This file defines a `Vector` type.

use pgrx::callconv::ArgAbi;
use pgrx::callconv::BoxRet;
use pgrx::datum::FromDatum;
use pgrx::datum::IntoDatum;
use pgrx::extension_sql;
use pgrx::pg_extern;
use pgrx::pg_sys;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::wrappers::rust_regtypein;
use std::ffi::CStr;
use std::ffi::CString;

/// The `vector` type
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[serde(transparent)]
pub(crate) struct Vector {
    pub(crate) value: Vec<f64>,
}

unsafe impl SqlTranslatable for Vector {
    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
        Ok(SqlMapping::As("vector".into()))
    }

    fn return_sql() -> Result<Returns, ReturnsError> {
        Ok(Returns::One(SqlMapping::As("vector".into())))
    }
}

impl FromDatum for Vector {
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        let value = <Vec<f64> as FromDatum>::from_polymorphic_datum(datum, is_null, typoid)?;
        Some(Self { value })
    }
}

impl IntoDatum for Vector {
    fn into_datum(self) -> Option<pg_sys::Datum> {
        self.value.into_datum()
    }

    fn type_oid() -> pg_sys::Oid {
        rust_regtypein::<Self>()
    }
}

unsafe impl<'fcx> ArgAbi<'fcx> for Vector
where
    Self: 'fcx,
{
    unsafe fn unbox_arg_unchecked(arg: ::pgrx::callconv::Arg<'_, 'fcx>) -> Self {
        unsafe {
            arg.unbox_arg_using_from_datum()
                .expect("expect it to be non-NULL")
        }
    }
}

unsafe impl BoxRet for Vector {
    unsafe fn box_into<'fcx>(
        self,
        fcinfo: &mut pgrx::callconv::FcInfo<'fcx>,
    ) -> pgrx::datum::Datum<'fcx> {
        match self.into_datum() {
            Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
            None => fcinfo.return_null(),
        }
    }
}

extension_sql!(
    r#"
CREATE TYPE vector; -- shell type
"#,
    name = "shell_type",
    bootstrap // declare this extension_sql block as the "bootstrap" block so that it happens first in sql generation
);

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_input(input: &CStr, _oid: pg_sys::Oid, type_modifier: i32) -> Vector {
    let value = match serde_json::from_str::<Vec<f64>>(
        input.to_str().expect("expect input to be UTF-8 encoded"),
    ) {
        Ok(v) => v,
        Err(e) => {
            pgrx::error!("failed to deserialize the input string due to error {}", e)
        }
    };
    let dimension = match u16::try_from(value.len()) {
        Ok(d) => d,
        Err(_) => {
            pgrx::error!("this vector's dimension [{}] is too large", value.len());
        }
    };

    // check the dimension in INPUT function if we know the expected dimension.
    if type_modifier != -1 {
        let expected_dimension = match u16::try_from(type_modifier) {
            Ok(d) => d,
            Err(_) => {
                panic!("failed to cast stored dimension [{}] to u16", type_modifier);
            }
        };

        if dimension != expected_dimension {
            pgrx::error!(
                "mismatched dimension, expected {}, found {}",
                expected_dimension,
                dimension
            );
        }
    }
    // If we don't know the type modifier, do not check the dimension.

    Vector { value }
}

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_output(value: Vector) -> CString {
    let value_serialized_string = serde_json::to_string(&value).unwrap();
    CString::new(value_serialized_string).expect("there should be no NUL in the middle")
}

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_modifier_input(list: pgrx::datum::Array<&CStr>) -> i32 {
    if list.len() != 1 {
        pgrx::error!("too many modifiers, expect 1")
    }

    let modifier = list
        .get(0)
        .expect("should be Some as len = 1")
        .expect("type modifier cannot be NULL");
    let Ok(dimension) = modifier
        .to_str()
        .expect("expect type modifiers to be UTF-8 encoded")
        .parse::<u16>()
    else {
        pgrx::error!("invalid dimension value, expect a number in range [1, 65535]")
    };

    dimension as i32
}

#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_modifier_output(type_modifier: i32) -> CString {
    CString::new(format!("({})", type_modifier)).expect("no NUL in the middle")
}

// create the actual type, specifying the input and output functions
extension_sql!(
    r#"
CREATE TYPE vector (
    INPUT = vector_input,
    OUTPUT = vector_output,
    TYPMOD_IN = vector_modifier_input,
    TYPMOD_OUT = vector_modifier_output,
    STORAGE = external
);
"#,
    name = "concrete_type",
    creates = [Type(Vector)],
    requires = [
        "shell_type",
        vector_input,
        vector_output,
        vector_modifier_input,
        vector_modifier_output
    ]
);

/// Cast a `vector` to a `vector`, the conversion is meaningless, but we do need
/// to do the dimension check here if we cannot get the `typmod` value in vector
/// type's input function.
#[pgrx::pg_extern(immutable, strict, parallel_safe, requires = ["concrete_type"])]
fn cast_vector_to_vector(vector: Vector, type_modifier: i32, _explicit: bool) -> Vector {
    let expected_dimension = u16::try_from(type_modifier).expect("invalid type_modifier") as usize;
    let dimension = vector.value.len();
    if vector.value.len() != expected_dimension {
        pgrx::error!(
            "mismatched dimension, expected {}, found {}",
            type_modifier,
            dimension
        );
    }

    vector
}

extension_sql!(
    r#"
    CREATE CAST (vector AS vector)
    WITH FUNCTION cast_vector_to_vector(vector, integer, boolean);
    "#,
    name = "cast_vector_to_vector",
    requires = ["concrete_type", cast_vector_to_vector]
);
点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消