patch 9.0.1959: Vim9: method parameter types are checked covariantly

Problem:  Vim9: the parameter and return types of an overriding object
          method are both checked covariantly.
Solution: Check object method argument types contra-variantly (similar
          to Dart); return types are still checked covariantly.

closes: #12965
closes: #13221

Signed-off-by: Christian Brabandt <cb@256bit.org>
Co-authored-by: Yegappan Lakshmanan <yegappan@yahoo.com>
diff --git a/src/proto/vim9class.pro b/src/proto/vim9class.pro
index 9edf354..31e2be7 100644
--- a/src/proto/vim9class.pro
+++ b/src/proto/vim9class.pro
@@ -29,5 +29,5 @@
 void method_not_found_msg(class_T *cl, vartype_T v_type, char_u *name, size_t len);
 void member_not_found_msg(class_T *cl, vartype_T v_type, char_u *name, size_t len);
 void f_instanceof(typval_T *argvars, typval_T *rettv);
-int class_instance_of(class_T *cl, class_T *other_cl);
+int class_instance_of(class_T *cl, class_T *other_cl, int covariance_check);
 /* vim: set ft=c : */
diff --git a/src/structs.h b/src/structs.h
index 009f16e..d05ae87 100644
--- a/src/structs.h
+++ b/src/structs.h
@@ -4798,14 +4798,19 @@
     WT_ARGUMENT,
     WT_VARIABLE,
     WT_MEMBER,
-    WT_METHOD,
+    WT_METHOD,		// object method
+    WT_METHOD_ARG,	// object method argument type
+    WT_METHOD_RETURN	// object method return type
 } wherekind_T;
 
-// Struct used to pass to error messages about where the error happened.
+// Struct used to pass the location of a type check.  Used in error messages to
+// indicate where the error happened.  Also used for doing covariance type
+// check for object method return type and contra-variance type check for
+// object method arguments.
 typedef struct {
     char	*wt_func_name;  // function name or NULL
     char	wt_index;	// argument or variable index, 0 means unknown
-    wherekind_T	wt_kind;	// "variable" when TRUE, "argument" otherwise
+    wherekind_T	wt_kind;	// type check location
 } where_T;
 
 #define WHERE_INIT {NULL, 0, WT_UNKNOWN}
diff --git a/src/testdir/test_vim9_class.vim b/src/testdir/test_vim9_class.vim
index 9799a2f..856ee03 100644
--- a/src/testdir/test_vim9_class.vim
+++ b/src/testdir/test_vim9_class.vim
@@ -6318,4 +6318,108 @@
   endfor
 enddef
 
+" Test for checking the type of the arguments and the return value of an object
+" method in an extended class.  The argument types are checked contra-variantly
+" and the return type is checked covariantly.
+def Test_extended_obj_method_type_check()
+  # A more generic argument type and a more specific return type are accepted.
+  var lines =<< trim END
+    vim9script
+
+    class A
+    endclass
+    class B extends A
+    endclass
+    class C extends B
+    endclass
+
+    class Foo
+      def Doit(p: B): B
+        return B.new()
+      enddef
+    endclass
+
+    class Bar extends Foo
+      def Doit(p: A): C
+        return C.new()
+      enddef
+    endclass
+  END
+  v9.CheckSourceSuccess(lines)
+
+  # A more specific argument type is rejected.
+  lines =<< trim END
+    vim9script
+
+    class A
+    endclass
+    class B extends A
+    endclass
+    class C extends B
+    endclass
+
+    class Foo
+      def Doit(p: B): B
+        return B.new()
+      enddef
+    endclass
+
+    class Bar extends Foo
+      def Doit(p: C): B
+        return B.new()
+      enddef
+    endclass
+  END
+  v9.CheckSourceFailure(lines, 'E1383: Method "Doit": type mismatch, expected func(object<B>): object<B> but got func(object<C>): object<B>', 20)
+
+  # A more generic return type is rejected.
+  lines =<< trim END
+    vim9script
+
+    class A
+    endclass
+    class B extends A
+    endclass
+    class C extends B
+    endclass
+
+    class Foo
+      def Doit(p: B): B
+        return B.new()
+      enddef
+    endclass
+
+    class Bar extends Foo
+      def Doit(p: B): A
+        return A.new()
+      enddef
+    endclass
+  END
+  v9.CheckSourceFailure(lines, 'E1383: Method "Doit": type mismatch, expected func(object<B>): object<B> but got func(object<B>): object<A>', 20)
+
+  # The return type may be a class implementing the interface that is returned
+  # by the method in the base class.
+  lines =<< trim END
+    vim9script
+
+    interface I
+    endinterface
+    class B implements I
+    endclass
+
+    class Foo
+      def Doit(): I
+        return B.new()
+      enddef
+    endclass
+
+    class Bar extends Foo
+      def Doit(): B
+        return B.new()
+      enddef
+    endclass
+  END
+  v9.CheckSourceSuccess(lines)
+enddef
+
 " vim: ts=8 sw=2 sts=2 expandtab tw=80 fdm=marker
diff --git a/src/version.c b/src/version.c
index 766333e..e625ea6 100644
--- a/src/version.c
+++ b/src/version.c
@@ -700,6 +700,8 @@
 static int included_patches[] =
 {   /* Add new patch number below this line */
 /**/
+    1959,
+/**/
     1958,
 /**/
     1957,
diff --git a/src/vim9class.c b/src/vim9class.c
index 885ac03..790c2c3 100644
--- a/src/vim9class.c
+++ b/src/vim9class.c
@@ -2561,7 +2561,7 @@
 {
     for (cctx_T *cctx = cctx_arg; cctx != NULL; cctx = cctx->ctx_outer)
 	if (cctx->ctx_ufunc != NULL
-			&& class_instance_of(cctx->ctx_ufunc->uf_class, cl))
+			&& class_instance_of(cctx->ctx_ufunc->uf_class, cl, TRUE))
 	    return TRUE;
     return FALSE;
 }
@@ -2871,29 +2871,39 @@
  * interfaces matches the class "other_cl".
  */
     int
-class_instance_of(class_T *cl, class_T *other_cl)
+class_instance_of(class_T *cl, class_T *other_cl, int covariance_check)
 {
     if (cl == other_cl)
 	return TRUE;
 
+    // When "covariance_check" is FALSE a contra-variance check is done, as
+    // used for the argument types of an overriding object method: the
+    // relation is reversed and "other_cl" must be an instance of "cl".  Swap
+    // the classes and do the covariance check, so that "cl" being an
+    // interface implemented by "other_cl" (directly, through a super
+    // interface or via a base class) is accepted too.
+    if (!covariance_check)
+	return class_instance_of(other_cl, cl, TRUE);
+
+    // Covariance check: "cl" must be "other_cl" or extend/implement it.
     // Recursively check the base classes.
     for (; cl != NULL; cl = cl->class_extends)
     {
 	if (cl == other_cl)
 	    return TRUE;
 	// Check the implemented interfaces and the super interfaces
 	for (int i = cl->class_interface_count - 1; i >= 0; --i)
 	{
 	    class_T	*intf = cl->class_interfaces_cl[i];
 	    while (intf != NULL)
 	    {
 		if (intf == other_cl)
 		    return TRUE;
 		// check the super interfaces
 		intf = intf->class_extends;
 	    }
 	}
     }
 
     return FALSE;
 }
@@ -2928,7 +2938,7 @@
 	    }
 
 	    if (class_instance_of(object_tv->vval.v_object->obj_class,
-			li->li_tv.vval.v_class) == TRUE)
+			li->li_tv.vval.v_class, TRUE) == TRUE)
 	    {
 		rettv->vval.v_number = VVAL_TRUE;
 		return;
@@ -2937,8 +2947,9 @@
     }
     else if (classinfo_tv->v_type == VAR_CLASS)
     {
-	rettv->vval.v_number = class_instance_of(object_tv->vval.v_object->obj_class,
-		classinfo_tv->vval.v_class);
+	rettv->vval.v_number = class_instance_of(
+					object_tv->vval.v_object->obj_class,
+					classinfo_tv->vval.v_class, TRUE);
     }
 }
 
diff --git a/src/vim9type.c b/src/vim9type.c
index 6ca4b29..4b8064d 100644
--- a/src/vim9type.c
+++ b/src/vim9type.c
@@ -759,6 +759,8 @@
 		    where.wt_func_name, typename1, typename2);
 	    break;
 	case WT_METHOD:
+	case WT_METHOD_ARG:
+	case WT_METHOD_RETURN:
 	    semsg(_(e_method_str_type_mismatch_expected_str_but_got_str),
 		    where.wt_func_name, typename1, typename2);
 	    break;
@@ -869,8 +871,15 @@
 	    {
 		if (actual->tt_member != NULL
 					    && actual->tt_member != &t_unknown)
+		{
+		    where_T  func_where = where;
+
+		    if (where.wt_kind == WT_METHOD)
+			func_where.wt_kind = WT_METHOD_RETURN;
 		    ret = check_type_maybe(expected->tt_member,
-					      actual->tt_member, FALSE, where);
+					    actual->tt_member, FALSE,
+					    func_where);
+		}
 		else
 		    ret = MAYBE;
 	    }
@@ -887,14 +896,20 @@
 
 		for (i = 0; i < expected->tt_argcount
 					       && i < actual->tt_argcount; ++i)
+		{
+		    where_T  func_where = where;
+		    if (where.wt_kind == WT_METHOD)
+			func_where.wt_kind = WT_METHOD_ARG;
+
 		    // Allow for using "any" argument type, lambda's have them.
 		    if (actual->tt_args[i] != &t_any && check_type(
 			    expected->tt_args[i], actual->tt_args[i], FALSE,
-								where) == FAIL)
+							func_where) == FAIL)
 		    {
 			ret = FAIL;
 			break;
 		    }
+		}
 	    }
 	    if (ret == OK && expected->tt_argcount >= 0
 						  && actual->tt_argcount == -1)
@@ -910,7 +925,10 @@
 	    if (actual->tt_class == NULL)
 		return OK;	// A null object matches
 
-	    if (class_instance_of(actual->tt_class, expected->tt_class) == FALSE)
+	    // For object method arguments, do a contra-variance type check in
+	    // an extended class.  For all others, do a co-variance type check.
+	    if (class_instance_of(actual->tt_class, expected->tt_class,
+				    where.wt_kind != WT_METHOD_ARG) == FALSE)
 		ret = FAIL;
 	}