Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions lib/typeprof/core/ast/call.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module TypeProf::Core
class AST
class CallBaseNode < Node
def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv)
def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv, forwarding_arguments: false)
super(raw_node, lenv)

@recv = recv
Expand All @@ -20,6 +20,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
@block_body = nil
@safe_navigation = raw_node.respond_to?(:safe_navigation?) && raw_node.safe_navigation?
@anonymous_block_forwarding = false
@forwarding_arguments = forwarding_arguments

if raw_args
args = []
Expand All @@ -30,7 +31,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
args << raw_arg.expression
@splat_flags << true
when Prism::ForwardingArgumentsNode
# TODO: Support forwarding arguments
@forwarding_arguments = true
else
args << raw_arg
@splat_flags << false
Expand Down Expand Up @@ -98,10 +99,10 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
attr_reader :positional_args, :splat_flags, :keyword_args
attr_reader :block_tbl, :block_f_args, :block_opt_positional_defaults, :block_body, :block_pass, :anonymous_block_forwarding
attr_reader :block_multi_targets
attr_reader :safe_navigation
attr_reader :safe_navigation, :forwarding_arguments

def subnodes = { recv:, positional_args:, keyword_args:, block_opt_positional_defaults:, block_body:, block_pass: }
def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding: }
def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding:, forwarding_arguments: }

def install0(genv)
recv = @recv ? @recv.install(genv) : @yield ? @lenv.get_var(:"*given_block") : @lenv.get_var(:"*self")
Expand All @@ -111,22 +112,30 @@ def install0(genv)
recv = NilFilter.new(genv, self, recv, false).next_vtx
end

positional_args = @positional_args.map do |arg|
if arg.is_a?(DummyNilNode)
@lenv.get_var(:"*anonymous_rest")
else
arg.install(genv)
if @forwarding_arguments
forward_a_args = (@lenv.forward_args || raise).to_actual_arguments(genv, @changes, self)
positional_args = forward_a_args.positionals
splat_flags = forward_a_args.splat_flags
keyword_args = forward_a_args.keywords
else
positional_args = @positional_args.map do |arg|
if arg.is_a?(DummyNilNode)
@lenv.get_var(:"*anonymous_rest")
else
arg.install(genv)
end
end
splat_flags = @splat_flags
keyword_args = @keyword_args ? @keyword_args.install(genv) : nil
end

keyword_args = @keyword_args ? @keyword_args.install(genv) : nil

if @block_body
block_body = @block_body # kinda type annotationty
block_tbl = @block_tbl || raise
block_body.lenv.forward_args = @lenv.forward_args
@lenv.locals.each {|var, vtx| block_body.lenv.locals[var] = vtx }
block_tbl.each {|var| block_body.lenv.locals[var] = Source.new(genv.nil_type) }
@block_body.lenv.locals[:"*self"] = @block_body.lenv.cref.get_self(genv)
block_body.lenv.locals[:"*self"] = block_body.lenv.cref.get_self(genv)

blk_f_args = []
if @block_f_args
Expand Down Expand Up @@ -156,7 +165,7 @@ def install0(genv)
block_body.lenv.set_var(var, vtx)
end
vars = []
@block_body.modified_vars(@lenv.locals.keys - block_tbl, vars)
block_body.modified_vars(@lenv.locals.keys - block_tbl, vars)
vars.uniq!
vars.each do |var|
vtx = @lenv.get_var(var)
Expand All @@ -165,9 +174,9 @@ def install0(genv)
block_body.lenv.set_var(var, nvtx)
end

@block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self)
@block_body.install(genv)
@block_body.lenv.add_next_box(@changes.add_escape_box(genv, @block_body.ret))
block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self)
block_body.install(genv)
block_body.lenv.add_next_box(@changes.add_escape_box(genv, block_body.ret))

vars.each do |var|
@changes.add_edge(genv, block_body.lenv.get_var(var), @lenv.get_var(var))
Expand All @@ -179,15 +188,17 @@ def install0(genv)
elem_vtx = @changes.add_splat_box(genv, blk_f_ary_arg, i).ret
@changes.add_edge(genv, elem_vtx, f_arg)
end
block = Block.new(self, blk_f_ary_arg, blk_f_args, @block_body.lenv.next_boxes)
block = Block.new(self, blk_f_ary_arg, blk_f_args, block_body.lenv.next_boxes)
blk_ty = Source.new(Type::Proc.new(genv, block))
elsif @block_pass
blk_ty = @block_pass.install(genv)
elsif @anonymous_block_forwarding
blk_ty = @lenv.get_var(:"*anonymous_block")
elsif @forwarding_arguments
blk_ty = forward_a_args.block
end

a_args = ActualArguments.new(positional_args, @splat_flags, keyword_args, blk_ty)
a_args = ActualArguments.new(positional_args, splat_flags, keyword_args, blk_ty)
box = @changes.add_method_call_box(genv, recv, @mid, a_args, !@recv)

block_body = @block_body
Expand Down Expand Up @@ -290,9 +301,9 @@ def initialize(raw_node, lenv)

class ForwardingSuperNode < CallBaseNode
def initialize(raw_node, lenv)
raw_args = nil # TODO: forward args properly
raw_args = nil
raw_block = raw_node.block
super(raw_node, nil, :"*super", nil, raw_args, nil, raw_block, lenv)
super(raw_node, nil, :"*super", nil, raw_args, nil, raw_block, lenv, forwarding_arguments: true)
end
end

Expand Down
17 changes: 17 additions & 0 deletions lib/typeprof/core/ast/method.rb
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,23 @@ def install0(genv)
block = @body.lenv.new_var(:"*given_block", self)
end

forward_opt_positionals = @opt_positionals.map do
elem_vtx = Vertex.new(self)
[Source.new(genv.gen_ary_type(elem_vtx)), elem_vtx]
end
forward_opt_keywords = @opt_keywords.map {|_name| Vertex.new(self) }
@body.lenv.forward_args = ForwardingArguments.new(
req_positionals,
forward_opt_positionals.map(&:first),
forward_opt_positionals.map(&:last),
rest_positionals,
post_positionals,
@req_keywords.zip(req_keywords),
@opt_keywords.zip(forward_opt_keywords),
rest_keywords,
block,
)

if @body
@body.lenv.locals[:"*expected_method_ret"] = Vertex.new(self)
@body.install(genv)
Expand Down
5 changes: 3 additions & 2 deletions lib/typeprof/core/env.rb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def code_units_cache
end

class LocalEnv
def initialize(file_context, cref, locals, return_boxes)
def initialize(file_context, cref, locals, return_boxes, forward_args = nil)
@file_context = file_context
@cref = cref
@locals = locals
Expand All @@ -328,9 +328,11 @@ def initialize(file_context, cref, locals, return_boxes)
@next_boxes = []
@ivar_narrowings = {}
@strict_const_scope = false
@forward_args = forward_args
end

attr_reader :file_context, :cref, :locals, :return_boxes, :break_vtx, :next_boxes, :strict_const_scope
attr_accessor :forward_args

def path = @file_context&.path
def code_range_from_node(node)
Expand Down Expand Up @@ -365,7 +367,6 @@ def get_break_vtx
@break_vtx ||= Vertex.new(:break_vtx)
end


def push_ivar_narrowing(name, narrowing)
raise unless narrowing.is_a?(Narrowing::Constraint)
(@ivar_narrowings[name] ||= []) << narrowing
Expand Down
156 changes: 156 additions & 0 deletions lib/typeprof/core/env/method.rb
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,162 @@ def get_keyword_arg(genv, changes, name)
end
end

class ForwardingArguments
def initialize(req_positionals, opt_positionals, opt_positional_elems, rest_positionals, post_positionals, req_keyword_pairs, opt_keyword_pairs, rest_keywords, block)
@req_positionals = req_positionals
@opt_positionals = opt_positionals
@opt_positional_elems = opt_positional_elems
@rest_positionals = rest_positionals
@post_positionals = post_positionals
@req_keyword_pairs = req_keyword_pairs
@opt_keyword_pairs = opt_keyword_pairs
@rest_keywords = rest_keywords
@block = block
end

attr_reader :block

def to_actual_arguments(genv, changes, node)
positionals = @req_positionals.dup
splat_flags = ::Array.new(positionals.size, false)

@opt_positionals.each do |arg|
positionals << arg
splat_flags << true
end

if @rest_positionals
positionals << @rest_positionals
splat_flags << true
end

@post_positionals.each do |arg|
positionals << arg
splat_flags << false
end

keywords = build_keyword_args(genv, changes, node)
ActualArguments.new(positionals, splat_flags, keywords, @block)
end

def accept_actual_arguments(genv, changes, a_args)
if a_args.splat_flags.any?
start_rest = [a_args.splat_flags.index(true), @req_positionals.size + @opt_positionals.size].min
end_rest = [a_args.splat_flags.rindex(true) + 1, a_args.positionals.size - @post_positionals.size].max
rest_vtxs = a_args.get_rest_args(genv, changes, start_rest, end_rest)

@req_positionals.each_with_index do |f_vtx, i|
if i < start_rest
changes.add_edge(genv, a_args.positionals[i], f_vtx)
else
rest_vtxs.each do |vtx|
changes.add_edge(genv, vtx, f_vtx)
end
end
end

@opt_positional_elems.each_with_index do |elem_vtx, i|
i += @req_positionals.size
if i < start_rest
changes.add_edge(genv, a_args.positionals[i], elem_vtx)
else
rest_vtxs.each do |vtx|
changes.add_edge(genv, vtx, elem_vtx)
end
end
end

@post_positionals.each_with_index do |f_vtx, i|
i += a_args.positionals.size - @post_positionals.size
if end_rest <= i
changes.add_edge(genv, a_args.positionals[i], f_vtx)
else
rest_vtxs.each do |vtx|
changes.add_edge(genv, vtx, f_vtx)
end
end
end

else
@req_positionals.each_with_index do |f_vtx, i|
changes.add_edge(genv, a_args.positionals[i], f_vtx)
end

@post_positionals.each_with_index do |f_vtx, i|
i -= @post_positionals.size
changes.add_edge(genv, a_args.positionals[i], f_vtx)
end

start_rest = @req_positionals.size
end_rest = a_args.positionals.size - @post_positionals.size
i = 0
while i < @opt_positional_elems.size && start_rest < end_rest
changes.add_edge(genv, a_args.positionals[start_rest], @opt_positional_elems[i])
i += 1
start_rest += 1
end
end

changes.add_edge(genv, a_args.block, @block) if @block && a_args.block

return unless a_args.keywords

@req_keyword_pairs.each do |name, f_vtx|
changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx)
end

@opt_keyword_pairs.each do |name, f_vtx|
changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx)
end

if @rest_keywords
named_keys = @req_keyword_pairs.map(&:first) + @opt_keyword_pairs.map(&:first)
a_args.keywords.each_type do |kw_ty|
case kw_ty
when Type::Record
rest_fields = kw_ty.fields.reject {|key, _| named_keys.include?(key) }
base = kw_ty.base_type(genv)
rest_record = Type::Record.new(genv, rest_fields, base)
changes.add_edge(genv, Source.new(rest_record), @rest_keywords)
when Type::Hash, Type::Instance
changes.add_edge(genv, Source.new(kw_ty), @rest_keywords)
end
end
end
end

private

def build_keyword_args(genv, changes, node)
return nil if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty? && !@rest_keywords
return @rest_keywords if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty?

unified_key = Vertex.new(node)
unified_val = Vertex.new(node)
literal_pairs = {}

@req_keyword_pairs.each do |name, vtx|
changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key)
changes.add_edge(genv, vtx, unified_val)
literal_pairs[name] = vtx
end

@opt_keyword_pairs.each do |name, vtx|
changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key)
changes.add_edge(genv, vtx, unified_val)
end

base_hash_type = genv.gen_hash_type(unified_key, unified_val)
changes.add_hash_splat_box(genv, @rest_keywords, unified_key, unified_val) if @rest_keywords

if literal_pairs.empty?
Source.new(base_hash_type)
else
Source.new(Type::Record.new(genv, literal_pairs, base_hash_type))
end
end
end

class Block
#: (AST::CallBaseNode, Vertex, Array[Vertex], Array[EscapeBox]) -> void
def initialize(node, f_ary_arg, f_args, next_boxes)
Expand Down
6 changes: 4 additions & 2 deletions lib/typeprof/core/graph/box.rb
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,6 @@ def run0(genv, changes)
end

def pass_arguments(changes, genv, a_args)
a_args = normalize_keyword_hash_argument_for_def(a_args)

if a_args.splat_flags.any?
# there is at least one splat actual argument

Expand Down Expand Up @@ -898,7 +896,11 @@ def normalize_keyword_hash_argument_for_def(a_args)
end

def call(changes, genv, a_args, ret)
a_args = normalize_keyword_hash_argument_for_def(a_args)
if pass_arguments(changes, genv, a_args)
if @node.is_a?(AST::DefNode)
@node.body.lenv.forward_args&.accept_actual_arguments(genv, changes, a_args)
end
changes.add_edge(genv, a_args.block, @f_args.block) if @f_args.block && a_args.block

changes.add_edge(genv, @ret, ret)
Expand Down
Loading
Loading