diff --git a/lib/typeprof/core/ast/call.rb b/lib/typeprof/core/ast/call.rb index d228e823..101fb43a 100644 --- a/lib/typeprof/core/ast/call.rb +++ b/lib/typeprof/core/ast/call.rb @@ -1,7 +1,7 @@ module TypeProf::Core class AST class CallBaseNode < Node - def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv) + def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv, forwarding_arguments: false) super(raw_node, lenv) @recv = recv @@ -20,6 +20,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc @block_body = nil @safe_navigation = raw_node.respond_to?(:safe_navigation?) && raw_node.safe_navigation? @anonymous_block_forwarding = false + @forwarding_arguments = forwarding_arguments if raw_args args = [] @@ -30,7 +31,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc args << raw_arg.expression @splat_flags << true when Prism::ForwardingArgumentsNode - # TODO: Support forwarding arguments + @forwarding_arguments = true else args << raw_arg @splat_flags << false @@ -98,10 +99,10 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc attr_reader :positional_args, :splat_flags, :keyword_args attr_reader :block_tbl, :block_f_args, :block_opt_positional_defaults, :block_body, :block_pass, :anonymous_block_forwarding attr_reader :block_multi_targets - attr_reader :safe_navigation + attr_reader :safe_navigation, :forwarding_arguments def subnodes = { recv:, positional_args:, keyword_args:, block_opt_positional_defaults:, block_body:, block_pass: } - def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding: } + def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding:, forwarding_arguments: } def install0(genv) recv = @recv ? @recv.install(genv) : @yield ? @lenv.get_var(:"*given_block") : @lenv.get_var(:"*self") @@ -111,22 +112,30 @@ def install0(genv) recv = NilFilter.new(genv, self, recv, false).next_vtx end - positional_args = @positional_args.map do |arg| - if arg.is_a?(DummyNilNode) - @lenv.get_var(:"*anonymous_rest") - else - arg.install(genv) + if @forwarding_arguments + forward_a_args = (@lenv.forward_args || raise).to_actual_arguments(genv, @changes, self) + positional_args = forward_a_args.positionals + splat_flags = forward_a_args.splat_flags + keyword_args = forward_a_args.keywords + else + positional_args = @positional_args.map do |arg| + if arg.is_a?(DummyNilNode) + @lenv.get_var(:"*anonymous_rest") + else + arg.install(genv) + end end + splat_flags = @splat_flags + keyword_args = @keyword_args ? @keyword_args.install(genv) : nil end - keyword_args = @keyword_args ? @keyword_args.install(genv) : nil - if @block_body block_body = @block_body # kinda type annotationty block_tbl = @block_tbl || raise + block_body.lenv.forward_args = @lenv.forward_args @lenv.locals.each {|var, vtx| block_body.lenv.locals[var] = vtx } block_tbl.each {|var| block_body.lenv.locals[var] = Source.new(genv.nil_type) } - @block_body.lenv.locals[:"*self"] = @block_body.lenv.cref.get_self(genv) + block_body.lenv.locals[:"*self"] = block_body.lenv.cref.get_self(genv) blk_f_args = [] if @block_f_args @@ -156,7 +165,7 @@ def install0(genv) block_body.lenv.set_var(var, vtx) end vars = [] - @block_body.modified_vars(@lenv.locals.keys - block_tbl, vars) + block_body.modified_vars(@lenv.locals.keys - block_tbl, vars) vars.uniq! vars.each do |var| vtx = @lenv.get_var(var) @@ -165,9 +174,9 @@ def install0(genv) block_body.lenv.set_var(var, nvtx) end - @block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self) - @block_body.install(genv) - @block_body.lenv.add_next_box(@changes.add_escape_box(genv, @block_body.ret)) + block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self) + block_body.install(genv) + block_body.lenv.add_next_box(@changes.add_escape_box(genv, block_body.ret)) vars.each do |var| @changes.add_edge(genv, block_body.lenv.get_var(var), @lenv.get_var(var)) @@ -179,15 +188,17 @@ def install0(genv) elem_vtx = @changes.add_splat_box(genv, blk_f_ary_arg, i).ret @changes.add_edge(genv, elem_vtx, f_arg) end - block = Block.new(self, blk_f_ary_arg, blk_f_args, @block_body.lenv.next_boxes) + block = Block.new(self, blk_f_ary_arg, blk_f_args, block_body.lenv.next_boxes) blk_ty = Source.new(Type::Proc.new(genv, block)) elsif @block_pass blk_ty = @block_pass.install(genv) elsif @anonymous_block_forwarding blk_ty = @lenv.get_var(:"*anonymous_block") + elsif @forwarding_arguments + blk_ty = forward_a_args.block end - a_args = ActualArguments.new(positional_args, @splat_flags, keyword_args, blk_ty) + a_args = ActualArguments.new(positional_args, splat_flags, keyword_args, blk_ty) box = @changes.add_method_call_box(genv, recv, @mid, a_args, !@recv) block_body = @block_body @@ -290,9 +301,9 @@ def initialize(raw_node, lenv) class ForwardingSuperNode < CallBaseNode def initialize(raw_node, lenv) - raw_args = nil # TODO: forward args properly + raw_args = nil raw_block = raw_node.block - super(raw_node, nil, :"*super", nil, raw_args, nil, raw_block, lenv) + super(raw_node, nil, :"*super", nil, raw_args, nil, raw_block, lenv, forwarding_arguments: true) end end diff --git a/lib/typeprof/core/ast/method.rb b/lib/typeprof/core/ast/method.rb index 843dcca8..54c0c4cd 100644 --- a/lib/typeprof/core/ast/method.rb +++ b/lib/typeprof/core/ast/method.rb @@ -271,6 +271,23 @@ def install0(genv) block = @body.lenv.new_var(:"*given_block", self) end + forward_opt_positionals = @opt_positionals.map do + elem_vtx = Vertex.new(self) + [Source.new(genv.gen_ary_type(elem_vtx)), elem_vtx] + end + forward_opt_keywords = @opt_keywords.map {|_name| Vertex.new(self) } + @body.lenv.forward_args = ForwardingArguments.new( + req_positionals, + forward_opt_positionals.map(&:first), + forward_opt_positionals.map(&:last), + rest_positionals, + post_positionals, + @req_keywords.zip(req_keywords), + @opt_keywords.zip(forward_opt_keywords), + rest_keywords, + block, + ) + if @body @body.lenv.locals[:"*expected_method_ret"] = Vertex.new(self) @body.install(genv) diff --git a/lib/typeprof/core/env.rb b/lib/typeprof/core/env.rb index 91117468..3d84150a 100644 --- a/lib/typeprof/core/env.rb +++ b/lib/typeprof/core/env.rb @@ -319,7 +319,7 @@ def code_units_cache end class LocalEnv - def initialize(file_context, cref, locals, return_boxes) + def initialize(file_context, cref, locals, return_boxes, forward_args = nil) @file_context = file_context @cref = cref @locals = locals @@ -328,9 +328,11 @@ def initialize(file_context, cref, locals, return_boxes) @next_boxes = [] @ivar_narrowings = {} @strict_const_scope = false + @forward_args = forward_args end attr_reader :file_context, :cref, :locals, :return_boxes, :break_vtx, :next_boxes, :strict_const_scope + attr_accessor :forward_args def path = @file_context&.path def code_range_from_node(node) @@ -365,7 +367,6 @@ def get_break_vtx @break_vtx ||= Vertex.new(:break_vtx) end - def push_ivar_narrowing(name, narrowing) raise unless narrowing.is_a?(Narrowing::Constraint) (@ivar_narrowings[name] ||= []) << narrowing diff --git a/lib/typeprof/core/env/method.rb b/lib/typeprof/core/env/method.rb index 20c23dec..f7746e00 100644 --- a/lib/typeprof/core/env/method.rb +++ b/lib/typeprof/core/env/method.rb @@ -95,6 +95,162 @@ def get_keyword_arg(genv, changes, name) end end + class ForwardingArguments + def initialize(req_positionals, opt_positionals, opt_positional_elems, rest_positionals, post_positionals, req_keyword_pairs, opt_keyword_pairs, rest_keywords, block) + @req_positionals = req_positionals + @opt_positionals = opt_positionals + @opt_positional_elems = opt_positional_elems + @rest_positionals = rest_positionals + @post_positionals = post_positionals + @req_keyword_pairs = req_keyword_pairs + @opt_keyword_pairs = opt_keyword_pairs + @rest_keywords = rest_keywords + @block = block + end + + attr_reader :block + + def to_actual_arguments(genv, changes, node) + positionals = @req_positionals.dup + splat_flags = ::Array.new(positionals.size, false) + + @opt_positionals.each do |arg| + positionals << arg + splat_flags << true + end + + if @rest_positionals + positionals << @rest_positionals + splat_flags << true + end + + @post_positionals.each do |arg| + positionals << arg + splat_flags << false + end + + keywords = build_keyword_args(genv, changes, node) + ActualArguments.new(positionals, splat_flags, keywords, @block) + end + + def accept_actual_arguments(genv, changes, a_args) + if a_args.splat_flags.any? + start_rest = [a_args.splat_flags.index(true), @req_positionals.size + @opt_positionals.size].min + end_rest = [a_args.splat_flags.rindex(true) + 1, a_args.positionals.size - @post_positionals.size].max + rest_vtxs = a_args.get_rest_args(genv, changes, start_rest, end_rest) + + @req_positionals.each_with_index do |f_vtx, i| + if i < start_rest + changes.add_edge(genv, a_args.positionals[i], f_vtx) + else + rest_vtxs.each do |vtx| + changes.add_edge(genv, vtx, f_vtx) + end + end + end + + @opt_positional_elems.each_with_index do |elem_vtx, i| + i += @req_positionals.size + if i < start_rest + changes.add_edge(genv, a_args.positionals[i], elem_vtx) + else + rest_vtxs.each do |vtx| + changes.add_edge(genv, vtx, elem_vtx) + end + end + end + + @post_positionals.each_with_index do |f_vtx, i| + i += a_args.positionals.size - @post_positionals.size + if end_rest <= i + changes.add_edge(genv, a_args.positionals[i], f_vtx) + else + rest_vtxs.each do |vtx| + changes.add_edge(genv, vtx, f_vtx) + end + end + end + + else + @req_positionals.each_with_index do |f_vtx, i| + changes.add_edge(genv, a_args.positionals[i], f_vtx) + end + + @post_positionals.each_with_index do |f_vtx, i| + i -= @post_positionals.size + changes.add_edge(genv, a_args.positionals[i], f_vtx) + end + + start_rest = @req_positionals.size + end_rest = a_args.positionals.size - @post_positionals.size + i = 0 + while i < @opt_positional_elems.size && start_rest < end_rest + changes.add_edge(genv, a_args.positionals[start_rest], @opt_positional_elems[i]) + i += 1 + start_rest += 1 + end + end + + changes.add_edge(genv, a_args.block, @block) if @block && a_args.block + + return unless a_args.keywords + + @req_keyword_pairs.each do |name, f_vtx| + changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx) + end + + @opt_keyword_pairs.each do |name, f_vtx| + changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx) + end + + if @rest_keywords + named_keys = @req_keyword_pairs.map(&:first) + @opt_keyword_pairs.map(&:first) + a_args.keywords.each_type do |kw_ty| + case kw_ty + when Type::Record + rest_fields = kw_ty.fields.reject {|key, _| named_keys.include?(key) } + base = kw_ty.base_type(genv) + rest_record = Type::Record.new(genv, rest_fields, base) + changes.add_edge(genv, Source.new(rest_record), @rest_keywords) + when Type::Hash, Type::Instance + changes.add_edge(genv, Source.new(kw_ty), @rest_keywords) + end + end + end + end + + private + + def build_keyword_args(genv, changes, node) + return nil if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty? && !@rest_keywords + return @rest_keywords if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty? + + unified_key = Vertex.new(node) + unified_val = Vertex.new(node) + literal_pairs = {} + + @req_keyword_pairs.each do |name, vtx| + changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key) + changes.add_edge(genv, vtx, unified_val) + literal_pairs[name] = vtx + end + + @opt_keyword_pairs.each do |name, vtx| + changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key) + changes.add_edge(genv, vtx, unified_val) + end + + base_hash_type = genv.gen_hash_type(unified_key, unified_val) + changes.add_hash_splat_box(genv, @rest_keywords, unified_key, unified_val) if @rest_keywords + + if literal_pairs.empty? + Source.new(base_hash_type) + else + Source.new(Type::Record.new(genv, literal_pairs, base_hash_type)) + end + end + end + class Block #: (AST::CallBaseNode, Vertex, Array[Vertex], Array[EscapeBox]) -> void def initialize(node, f_ary_arg, f_args, next_boxes) diff --git a/lib/typeprof/core/graph/box.rb b/lib/typeprof/core/graph/box.rb index 1e62042c..30150bd8 100644 --- a/lib/typeprof/core/graph/box.rb +++ b/lib/typeprof/core/graph/box.rb @@ -760,8 +760,6 @@ def run0(genv, changes) end def pass_arguments(changes, genv, a_args) - a_args = normalize_keyword_hash_argument_for_def(a_args) - if a_args.splat_flags.any? # there is at least one splat actual argument @@ -898,7 +896,11 @@ def normalize_keyword_hash_argument_for_def(a_args) end def call(changes, genv, a_args, ret) + a_args = normalize_keyword_hash_argument_for_def(a_args) if pass_arguments(changes, genv, a_args) + if @node.is_a?(AST::DefNode) + @node.body.lenv.forward_args&.accept_actual_arguments(genv, changes, a_args) + end changes.add_edge(genv, a_args.block, @f_args.block) if @f_args.block && a_args.block changes.add_edge(genv, @ret, ret) diff --git a/scenario/args/forwarding_arguments.rb b/scenario/args/forwarding_arguments.rb index cddbdf6b..43090449 100644 --- a/scenario/args/forwarding_arguments.rb +++ b/scenario/args/forwarding_arguments.rb @@ -19,3 +19,37 @@ def foo(a, ...) class Object def foo: (Integer, *untyped, **untyped) -> Integer end + +## update +def foo(...) + bar(...) +end + +def bar(*a, **b) + [a, b] +end + +foo(1, x: 4, y: 5) + +## assert +class Object + def foo: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }] + def bar: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }] +end + +## update +def foo(...) + 1.times { bar(...) } +end + +def bar(*a, **b) + [a, b] +end + +foo(1, x: 4, y: 5) + +## assert +class Object + def foo: (*Integer, **Integer) -> Integer + def bar: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }] +end diff --git a/scenario/known-issues/forwarding-arguments.rb b/scenario/known-issues/forwarding-arguments.rb deleted file mode 100644 index 7f012071..00000000 --- a/scenario/known-issues/forwarding-arguments.rb +++ /dev/null @@ -1,16 +0,0 @@ -## update -def foo(...) - bar(...) -end - -def bar(*a, **b) - [a, b] -end - -foo(1, x: 4, y: 5) - -## assert -class Object - def foo: (*Integer, **Hash[:x | :y, Integer]) -> [Array[Integer], Hash[:x | :y, Integer]] - def bar: (*Integer, **Hash[:x | :y, Integer]) -> [Array[Integer], Hash[:x | :y, Integer]] -end diff --git a/scenario/misc/super.rb b/scenario/misc/super.rb index a097772d..a2531a06 100644 --- a/scenario/misc/super.rb +++ b/scenario/misc/super.rb @@ -36,3 +36,26 @@ def [](key) class StringifyKeyHash < Hash def []: (untyped) -> untyped end + +## update +class SuperBase + def foo(*a, **b) + [a, b] + end +end + +class SuperChild < SuperBase + def foo(...) + super(...) + end +end + +SuperChild.new.foo(1, x: 4, y: 5) + +## assert +class SuperBase + def foo: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }] +end +class SuperChild < SuperBase + def foo: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }] +end