diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 572b7c3cd..210ac5249 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -478,6 +478,7 @@ pub enum FunctionKind { IsSubclass, Dataclass, DataclassField, + DataclassReplace, /// `typing.dataclass_transform`. Note that this is `dataclass_transform` itself, *not* the /// decorator created by a `dataclass_transform(...)` call. See /// https://typing.python.org/en/latest/spec/dataclasses.html#specification. @@ -810,6 +811,7 @@ impl FunctionKind { ("builtins", None, "classmethod") => Self::ClassMethod, ("dataclasses", None, "dataclass") => Self::Dataclass, ("dataclasses", None, "field") => Self::DataclassField, + ("dataclasses", None, "replace") => Self::DataclassReplace, ("typing", None, "overload") => Self::Overload, ("typing", None, "override") => Self::Override, ("typing", None, "cast") => Self::Cast, @@ -840,6 +842,7 @@ impl FunctionKind { Self::ClassMethod => ModuleName::builtins(), Self::Dataclass => ModuleName::dataclasses(), Self::DataclassField => ModuleName::dataclasses(), + Self::DataclassReplace => ModuleName::dataclasses(), Self::DataclassTransform => ModuleName::typing(), Self::Final => ModuleName::typing(), Self::Overload => ModuleName::typing(), @@ -865,6 +868,7 @@ impl FunctionKind { Self::ClassMethod => Cow::Owned(Name::new_static("classmethod")), Self::Dataclass => Cow::Owned(Name::new_static("dataclass")), Self::DataclassField => Cow::Owned(Name::new_static("field")), + Self::DataclassReplace => Cow::Owned(Name::new_static("replace")), Self::DataclassTransform => Cow::Owned(Name::new_static("dataclass_transform")), Self::Final => Cow::Owned(Name::new_static("final")), Self::Overload => Cow::Owned(Name::new_static("overload")), @@ -890,6 +894,7 @@ impl FunctionKind { Self::ClassMethod => None, Self::Dataclass => None, Self::DataclassField => None, + Self::DataclassReplace => None, Self::DataclassTransform => None, Self::Final => None, Self::Overload => None, diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 28ef61aa4..4046b4a3a 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -1243,6 +1243,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors, ) } + Some(CalleeKind::Function(FunctionKind::DataclassReplace)) => { + self.call_dataclasses_replace( + &x.arguments.args, + &x.arguments.keywords, + x.arguments.range, + errors, + ) + } // Treat assert_type and reveal_type like pseudo-builtins for convenience. Note that we still // log a name-not-found error, but we also assert/reveal the type as requested. None if ty.is_error() && is_special_name(&x.func, "assert_type") => self diff --git a/pyrefly/lib/alt/class/dataclass.rs b/pyrefly/lib/alt/class/dataclass.rs index 03e625cad..1e93338cf 100644 --- a/pyrefly/lib/alt/class/dataclass.rs +++ b/pyrefly/lib/alt/class/dataclass.rs @@ -11,8 +11,11 @@ use dupe::Dupe; use pyrefly_python::dunder; use pyrefly_util::prelude::SliceExt; use ruff_python_ast::Arguments; +use ruff_python_ast::Expr; use ruff_python_ast::Expr::EllipsisLiteral; +use ruff_python_ast::Keyword; use ruff_python_ast::name::Name; +use ruff_text_size::Ranged; use ruff_text_size::TextRange; use starlark_map::small_map::SmallMap; use starlark_map::small_set::SmallSet; @@ -43,17 +46,21 @@ use crate::error::context::TypeCheckKind; use crate::types::callable::Callable; use crate::types::callable::FuncMetadata; use crate::types::callable::Function; +use crate::types::callable::FunctionKind; use crate::types::callable::Param; use crate::types::callable::ParamList; use crate::types::callable::Params; use crate::types::callable::Required; use crate::types::class::Class; +use crate::types::class::ClassKind; +use crate::types::class::ClassType; use crate::types::display::ClassDisplayContext; use crate::types::keywords::ConverterMap; use crate::types::keywords::DataclassFieldKeywords; use crate::types::keywords::TypeMap; use crate::types::literal::Lit; use crate::types::types::AnyStyle; +use crate::types::types::CalleeKind; use crate::types::types::Type; impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { @@ -180,6 +187,314 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Some(ClassSynthesizedFields::new(fields)) } + pub fn call_dataclasses_replace( + &self, + args: &[Expr], + keywords: &[Keyword], + range: TextRange, + errors: &ErrorCollector, + ) -> Type { + let call_args = args.map(CallArg::expr_maybe_starred); + let call_keywords = keywords.map(CallKeyword::new); + let obj_name = Name::new_static("obj"); + + let obj_ty = match args.first() { + Some(first) => self.expr_infer(first, errors), + None => Type::Any(AnyStyle::Implicit), + }; + let mut dataclass_types = SmallSet::new(); + let mut saw_any = false; + let mut saw_non_dataclass = false; + self.distribute_over_union(&obj_ty, |ty| { + match ty { + Type::Any(_) => saw_any = true, + Type::ClassType(cls) => { + let cls_metadata = self.get_metadata_for_class(cls.class_object()); + let allow_dataclass = cls_metadata.dataclass_metadata().is_some_and(|m| { + m.field_specifiers.iter().any(|k| { + matches!( + k, + CalleeKind::Function(FunctionKind::DataclassField) + | CalleeKind::Class(ClassKind::DataclassField) + ) + }) + }); + if allow_dataclass { + dataclass_types.insert(cls.clone()); + } else { + saw_non_dataclass = true; + } + } + _ => saw_non_dataclass = true, + } + ty.clone() + }); + + let metadata = FuncMetadata { + kind: FunctionKind::DataclassReplace, + flags: Default::default(), + }; + + if dataclass_types.is_empty() { + if !saw_any { + self.error( + errors, + args.first().map(|x| x.range()).unwrap_or(range), + ErrorInfo::Kind(ErrorKind::InvalidArgument), + "dataclasses.replace() should be called on dataclass instances".to_owned(), + ); + } + return self.callable_infer( + Callable::list( + ParamList::new(vec![ + Param::PosOnly(Some(obj_name.clone()), obj_ty.clone(), Required::Required), + Param::Kwargs(None, Type::Any(AnyStyle::Implicit)), + ]), + if saw_any { + Type::Any(AnyStyle::Implicit) + } else { + Type::any_error() + }, + ), + Some(&FunctionKind::DataclassReplace), + None, + None, + &call_args, + &call_keywords, + range, + errors, + errors, + None, + None, + None, + ); + } + + if saw_non_dataclass { + self.error( + errors, + args.first().map(|x| x.range()).unwrap_or(range), + ErrorInfo::Kind(ErrorKind::InvalidArgument), + format!( + "dataclasses.replace() expects a dataclass instance; got {}", + self.for_display(obj_ty.clone()) + ), + ); + return Type::any_error(); + } + + let mut overloads = Vec::new(); + for dt in &dataclass_types { + if let Some(function) = self.build_replace_function(&obj_name, dt, &metadata, errors) { + overloads.push(function); + } + } + if overloads.is_empty() { + return Type::any_error(); + } + + // Validate that every provided keyword is accepted by all overloads unless any overload + // uses **kwargs. + let allowed_common = overloads + .iter() + .map(|ov| self.allowed_replace_keywords(&ov.1.signature)) + .try_fold(None, |acc: Option>, (allowed, allow_any)| { + if allow_any { + return Err(()); + } + Ok(Some(match acc { + None => allowed, + Some(existing) => { + let mut intersection = SmallSet::new(); + for name in existing { + if allowed.contains(&name) { + intersection.insert(name); + } + } + intersection + } + })) + }); + if let Ok(Some(allowed)) = allowed_common { + for kw in keywords { + if let Some(id) = &kw.arg + && !allowed.contains(&id.id) + { + self.error( + errors, + kw.range, + ErrorInfo::Kind(ErrorKind::UnexpectedKeyword), + format!( + "Unexpected keyword argument `{}` in function `dataclasses.replace`", + id.id + ), + ); + return Type::any_error(); + } + } + } + + if overloads.len() == 1 { + let func = overloads.pop().unwrap(); + return self.callable_infer( + func.1.signature, + Some(&FunctionKind::DataclassReplace), + func.0.as_deref(), + None, + &call_args, + &call_keywords, + range, + errors, + errors, + None, + None, + None, + ); + } + + let Ok(overloads) = Vec1::try_from_vec(overloads) else { + return Type::any_error(); + }; + let (ret, sig) = self.call_overloads( + overloads, + metadata, + None, + &call_args, + &call_keywords, + range, + errors, + None, + None, + None, + ); + self.callable_infer( + sig, + Some(&FunctionKind::DataclassReplace), + None, + None, + &call_args, + &call_keywords, + range, + errors, + errors, + None, + None, + None, + ); + ret + } + + fn build_replace_callable( + &self, + obj_name: &Name, + dataclass_type: &ClassType, + errors: &ErrorCollector, + ) -> Option { + let metadata = self.get_metadata_for_class(dataclass_type.class_object()); + let dataclass_metadata = metadata.dataclass_metadata()?; + + let mut params = vec![Param::PosOnly( + Some(obj_name.clone()), + dataclass_type.clone().to_type(), + Required::Required, + )]; + + let subst = dataclass_type.targs().substitution_map(); + let self_type = dataclass_type.clone().to_type(); + let type_transform = |mut ty: Type| { + ty.subst_self_type_mut(&self_type); + ty.subst_mut(&subst); + ty + }; + + let strict_default = dataclass_metadata.kws.strict; + for (name, field, field_flags) in + self.iter_fields(dataclass_type.class_object(), dataclass_metadata, true) + { + if !field_flags.init { + continue; + } + + let strict = field_flags.strict.unwrap_or(strict_default); + let has_default = !field.is_init_var() || field_flags.default.is_some(); + if field_flags.init_by_name { + params.push(self.as_param( + &field, + &name, + has_default, + true, + strict, + field_flags.converter_param.clone(), + &type_transform, + errors, + )); + } + if let Some(alias) = &field_flags.init_by_alias { + params.push(self.as_param( + &field, + alias, + has_default, + true, + strict, + field_flags.converter_param.clone(), + &type_transform, + errors, + )); + } + } + if dataclass_metadata.kws.extra { + params.push(Param::Kwargs(None, Type::Any(AnyStyle::Implicit))); + } + + Some(Callable::list( + ParamList::new(params), + dataclass_type.clone().to_type(), + )) + } + + fn build_replace_function( + &self, + obj_name: &Name, + dataclass_type: &ClassType, + metadata: &FuncMetadata, + errors: &ErrorCollector, + ) -> Option> { + self.build_replace_callable(obj_name, dataclass_type, errors) + .map(|signature| { + TargetWithTParams( + None, + Function { + signature, + metadata: metadata.clone(), + }, + ) + }) + } + + fn allowed_replace_keywords(&self, callable: &Callable) -> (SmallSet, bool) { + let mut allowed = SmallSet::new(); + let mut allow_any = false; + match &callable.params { + Params::List(params) => { + for param in params.items() { + match param { + Param::KwOnly(name, ..) => { + allowed.insert(name.clone()); + } + Param::Kwargs(_, _) => { + allow_any = true; + } + _ => {} + } + } + } + _ => { + allow_any = true; + } + } + (allowed, allow_any) + } + pub fn validate_frozen_dataclass_inheritance( &self, cls: &Class, diff --git a/pyrefly/lib/test/dataclasses.rs b/pyrefly/lib/test/dataclasses.rs index f98c2a080..33e76afa0 100644 --- a/pyrefly/lib/test/dataclasses.rs +++ b/pyrefly/lib/test/dataclasses.rs @@ -67,6 +67,234 @@ Data(0, 1) # E: Argument `Literal[1]` is not assignable to parameter `y` with t "#, ); +testcase!( + test_replace, + r#" +from dataclasses import dataclass, replace + +@dataclass +class Foo: + x: int + y: str + +f = Foo(1, "a") + +replace(f, x="wrong") # E: Argument `Literal['wrong']` is not assignable to parameter `x` with type `int` in function `dataclasses.replace` +replace(f, z=3) # E: Unexpected keyword argument `z` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_initvar_default, + r#" +from dataclasses import dataclass, field, InitVar, replace + +@dataclass +class WithInitVarDefault: + x: int + y: InitVar[str] = "ok" + +w = WithInitVarDefault(0) +replace(w) +replace(w, y="new") + "#, +); + +testcase!( + test_replace_initvar_required, + r#" +from dataclasses import dataclass, InitVar, replace + +@dataclass +class Foo: + x: int + y: InitVar[int] + +f = Foo(1, 2) + +replace(f) # E: Missing argument `y` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_positional_args_rejected, + r#" +from dataclasses import dataclass, replace + +@dataclass +class Foo: + x: int + y: str + +f = Foo(1, "a") + +replace(f, "extra") # E: Expected 1 positional argument, got 2 + "#, +); + +testcase!( + test_replace_init_false_field_rejected, + r#" +from dataclasses import dataclass, field, replace + +@dataclass +class WithInitFalse: + x: int + y: int = field(init=False, default=5) + +g = WithInitFalse(1) + +replace(g, y=10) # E: Unexpected keyword argument `y` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_classvar_rejected, + r#" +from dataclasses import dataclass, replace +from typing import ClassVar + +@dataclass +class Config: + limit: int + MAX_ID: ClassVar[int] = 100 + +c = Config(10) +replace(c, limit=20) +replace(c, MAX_ID=200) # E: Unexpected keyword argument `MAX_ID` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_union_mixed_dataclass, + r#" +from dataclasses import dataclass, replace +from typing import Union + +@dataclass +class Foo: + x: int + +class Bar: + x: int + +def f(obj: Union[Foo, Bar]): + replace(obj, x=0) # E: dataclasses.replace() expects a dataclass instance; got Bar | Foo + "#, +); + +testcase!( + test_replace_union_two_dataclasses_rejects_bad_kw, + r#" +from dataclasses import dataclass, replace +from typing import Union + +@dataclass +class A: + x: int + +@dataclass +class B: + y: int + +def f(obj: Union[A, B]): + replace(obj, z=1) # E: Unexpected keyword argument `z` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_starred_args_rejected, + r#" +from dataclasses import dataclass, replace + +@dataclass +class Foo: + x: int + y: int + +foo = Foo(1, 2) + +replace(foo, *()) +replace(foo, **{"x": "bad"}) # E: Argument `str` is not assignable to parameter `x` with type `int` in function `dataclasses.replace` +replace(foo, **{"z": 0}) # E: Unexpected keyword argument `z` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_rejects_obj_keyword, + r#" +from dataclasses import dataclass, replace + +@dataclass +class Foo: + x: int + +foo = Foo(1) + +replace(foo, obj=foo) # E: Unexpected keyword argument `obj` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_generic_consistency, + r#" +from dataclasses import dataclass, replace +from typing import TypeVar, Generic + +T = TypeVar("T") + +@dataclass +class Box(Generic[T]): + item: T + +b = Box(item=1) +replace(b, item=2) +replace(b, item="wrong") # E: Argument `Literal['wrong']` is not assignable to parameter `item` with type `int` + "#, +); + +testcase!( + test_replace_union_of_two_dataclasses_rejects_bad_kw, + r#" +from dataclasses import dataclass, replace +from typing import Union + +@dataclass +class A: + x: int + +@dataclass +class B: + y: int + +def f(obj: Union[A, B]): + replace(obj, z=1) # E: Unexpected keyword argument `z` in function `dataclasses.replace` + "#, +); + +testcase!( + test_replace_does_not_treat_dataclass_transform_as_dataclass, + r#" +from dataclasses import replace +from typing import dataclass_transform + +@dataclass_transform() +def my_dc(cls): + return cls + +@my_dc +class Model: + x: int + y: str + + def __init__(self, x: int, y: str) -> None: ... + +m = Model(1, "a") + +replace(m, x=2) # E: dataclasses.replace() should be called on dataclass instances + "#, +); + testcase!( test_inheritance, r#"