Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 78 additions & 3 deletions pypika_tortoise/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import TYPE_CHECKING, Any

from .context import SqlContext
from .enums import SqlTypes
from .terms import AggregateFunction, Function, Star, Term
from .enums import Dialects, SqlTypes
from .exceptions import DialectNotSupported
from .terms import AggregateFunction, Function, Star, Term, ValueWrapper
from .utils import builder

if TYPE_CHECKING:
Expand Down Expand Up @@ -242,8 +243,82 @@ def __init__(self, term: Any, alias: str | None = None) -> None:


class Trim(Function):
def __init__(
self,
term: Any,
trim_chars: str = " ",
alias: str | None = None,
) -> None:
args = [term]
if trim_chars != " ":
args.append(ValueWrapper(trim_chars))

super().__init__("TRIM", *args, alias=alias)

def get_function_sql(self, ctx: SqlContext) -> str:
if len(self.args) == 1:
return super().get_function_sql(ctx)

args_sql = [self.get_arg_sql(arg, ctx) for arg in self.args]
if ctx.dialect == Dialects.SQLITE:
args = ",".join(args_sql)
else:
args = f"BOTH {args_sql[1]} FROM {args_sql[0]}"

return "{name}({args})".format(
name=self.get_dialect_special_name(ctx.dialect) or self.name,
args=args,
)


class LTrim(Function):
def __init__(self, term: Any, alias: str | None = None) -> None:
Comment thread
waketzheng marked this conversation as resolved.
super().__init__("LTRIM", term, alias=alias)


class RTrim(Function):
def __init__(self, term: Any, alias: str | None = None) -> None:
super().__init__("TRIM", term, alias=alias)
super().__init__("RTRIM", term, alias=alias)


class _Pad(Function):
db_function: str

def __init__(
self, term: Any, length: int, fill_text: str = " ", alias: str | None = None
) -> None:
super().__init__(
self.db_function,
term,
ValueWrapper(length),
ValueWrapper(fill_text),
alias=alias,
)

def get_sql(self, ctx: SqlContext) -> str:
if ctx.dialect in [Dialects.SQLITE, Dialects.MSSQL]:
raise DialectNotSupported(f"{self.db_function} is not supported in {ctx.dialect}.")

return super().get_sql(ctx)


class LPad(_Pad):
db_function = "LPAD"


class RPad(_Pad):
db_function = "RPAD"


class Replace(Function):
def __init__(self, term: Any, search: str, replacement: str, alias: str | None = None) -> None:
super().__init__(
"REPLACE",
term,
ValueWrapper(search),
ValueWrapper(replacement),
alias=alias,
)


class SplitPart(Function):
Expand Down
115 changes: 115 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from pypika_tortoise import Table as T
from pypika_tortoise import functions as fn
from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
from pypika_tortoise.dialects.mssql import MSSQLQuery
from pypika_tortoise.dialects.postgresql import PostgreSQLQuery
from pypika_tortoise.dialects.sqlite import SQLLiteQuery
from pypika_tortoise.enums import SqlTypes
from pypika_tortoise.exceptions import DialectNotSupported


class FunctionTests(unittest.TestCase):
Expand Down Expand Up @@ -521,6 +524,118 @@ def test__substring(self):

self.assertEqual('SELECT SUBSTRING("foo",2,6) FROM "abc"', str(q))

def test__trim__field(self):
Comment thread
waketzheng marked this conversation as resolved.
q = Q.from_(self.t).select(fn.Trim(self.t.foo))

self.assertEqual('SELECT TRIM("foo") FROM "abc"', str(q))

def test__trim__field__chars(self):
q = Q.from_(self.t).select(fn.Trim(self.t.foo, trim_chars="x"))

self.assertEqual(
'SELECT TRIM(BOTH \'x\' FROM "foo") FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT)
)

def test__trim__field__chars__sqlite(self):
q = Q.from_(self.t).select(fn.Trim(self.t.foo, trim_chars="x"))

self.assertEqual('SELECT TRIM("foo",\'x\') FROM "abc"', str(q))

def test__trim__str(self):
q = Q.select(fn.Trim(" abc "))

self.assertEqual("SELECT TRIM(' abc ')", str(q))

def test__trim__str__chars(self):
q = Q.select(fn.Trim("xxabcxx", trim_chars="x"))

self.assertEqual(
"SELECT TRIM(BOTH 'x' FROM 'xxabcxx')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT)
)

def test__trim__str__chars__sqlite(self):
q = Q.select(fn.Trim("xxabcxx", trim_chars="x"))

self.assertEqual("SELECT TRIM('xxabcxx','x')", str(q))

def test__ltrim__str(self):
q = Q.select(fn.LTrim(" abc"))

self.assertEqual("SELECT LTRIM(' abc')", str(q))

def test__ltrim__field(self):
q = Q.from_(self.t).select(fn.LTrim(self.t.foo))

self.assertEqual('SELECT LTRIM("foo") FROM "abc"', str(q))

def test__rtrim__str(self):
q = Q.select(fn.RTrim("abc "))

self.assertEqual("SELECT RTRIM('abc ')", str(q))

def test__rtrim__field(self):
q = Q.from_(self.t).select(fn.RTrim(self.t.foo))

self.assertEqual('SELECT RTRIM("foo") FROM "abc"', str(q))

def test__lpad__str(self):
q = Q.select(fn.LPad("abc", 5))

self.assertEqual("SELECT LPAD('abc',5,' ')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT))

def test__lpad__str_with_fill(self):
q = Q.select(fn.LPad("abc", 5, "x"))

self.assertEqual("SELECT LPAD('abc',5,'x')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT))

def test__lpad__field(self):
q = Q.from_(self.t).select(fn.LPad(self.t.foo, 10, "-"))

self.assertEqual(
'SELECT LPAD("foo",10,\'-\') FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT)
)

def test__lpad__sqlite__mssql__raises(self):
q = Q.select(fn.LPad("abc", 5))

for dialect in [SQLLiteQuery.SQL_CONTEXT, MSSQLQuery.SQL_CONTEXT]:
with self.subTest(dialect=dialect), self.assertRaises(DialectNotSupported):
q.get_sql(dialect)

def test__rpad__str(self):
q = Q.select(fn.RPad("abc", 5))

self.assertEqual("SELECT RPAD('abc',5,' ')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT))

def test__rpad__str_with_fill(self):
q = Q.select(fn.RPad("abc", 5, "x"))

self.assertEqual("SELECT RPAD('abc',5,'x')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT))

def test__rpad__field(self):
q = Q.from_(self.t).select(fn.RPad(self.t.foo, 10, "-"))

self.assertEqual(
'SELECT RPAD("foo",10,\'-\') FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT)
)

def test__rpad__sqlite__mssql__raises(self):
q = Q.select(fn.RPad("abc", 5))

for dialect in [SQLLiteQuery.SQL_CONTEXT, MSSQLQuery.SQL_CONTEXT]:
with self.subTest(dialect=dialect), self.assertRaises(DialectNotSupported):
q.get_sql(dialect)

def test__replace__str(self):
q = Q.select(fn.Replace("abcde", "cd", "xx"))

self.assertEqual("SELECT REPLACE('abcde','cd','xx')", str(q))

def test__replace__field(self):
q = Q.from_(self.t).select(fn.Replace(self.t.foo, "old", "new"))

self.assertEqual("SELECT REPLACE(\"foo\",'old','new') FROM \"abc\"", str(q))


class CastTests(unittest.TestCase):
t = T("abc")
Expand Down
Loading