@@ -38,7 +38,7 @@ use crate::{
38
38
infer:: { BreakableKind , CoerceMany , Diverges , coerce:: CoerceNever } ,
39
39
make_binders,
40
40
mir:: { BorrowKind , MirSpan , MutBorrowKind , ProjectionElem } ,
41
- to_chalk_trait_id,
41
+ to_assoc_type_id , to_chalk_trait_id,
42
42
traits:: FnTrait ,
43
43
utils:: { self , elaborate_clause_supertraits} ,
44
44
} ;
@@ -245,7 +245,7 @@ impl InferenceContext<'_> {
245
245
}
246
246
247
247
fn deduce_closure_kind_from_predicate_clauses (
248
- & self ,
248
+ & mut self ,
249
249
expected_ty : & Ty ,
250
250
clauses : impl DoubleEndedIterator < Item = WhereClause > ,
251
251
closure_kind : ClosureKind ,
@@ -378,7 +378,7 @@ impl InferenceContext<'_> {
378
378
}
379
379
380
380
fn deduce_sig_from_projection (
381
- & self ,
381
+ & mut self ,
382
382
closure_kind : ClosureKind ,
383
383
projection_ty : & ProjectionTy ,
384
384
projected_ty : & Ty ,
@@ -392,13 +392,16 @@ impl InferenceContext<'_> {
392
392
393
393
// For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
394
394
// for closures and async closures, respectively.
395
- match closure_kind {
396
- ClosureKind :: Closure | ClosureKind :: Async
397
- if self . fn_trait_kind_from_trait_id ( trait_) . is_some ( ) =>
398
- {
399
- self . extract_sig_from_projection ( projection_ty, projected_ty)
400
- }
401
- _ => None ,
395
+ let fn_trait_kind = self . fn_trait_kind_from_trait_id ( trait_) ?;
396
+ if !matches ! ( closure_kind, ClosureKind :: Closure | ClosureKind :: Async ) {
397
+ return None ;
398
+ }
399
+ if fn_trait_kind. is_async ( ) {
400
+ // If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is,
401
+ // but we do know it must implement `Future<Output = X>`.
402
+ self . extract_async_fn_sig_from_projection ( projection_ty, projected_ty)
403
+ } else {
404
+ self . extract_sig_from_projection ( projection_ty, projected_ty)
402
405
}
403
406
}
404
407
@@ -424,6 +427,39 @@ impl InferenceContext<'_> {
424
427
) ) )
425
428
}
426
429
430
+ fn extract_async_fn_sig_from_projection (
431
+ & mut self ,
432
+ projection_ty : & ProjectionTy ,
433
+ projected_ty : & Ty ,
434
+ ) -> Option < FnSubst < Interner > > {
435
+ let arg_param_ty = projection_ty. substitution . as_slice ( Interner ) [ 1 ] . assert_ty_ref ( Interner ) ;
436
+
437
+ let TyKind :: Tuple ( _, input_tys) = arg_param_ty. kind ( Interner ) else {
438
+ return None ;
439
+ } ;
440
+
441
+ let ret_param_future_output = projected_ty;
442
+ let ret_param_future = self . table . new_type_var ( ) ;
443
+ let future_output =
444
+ LangItem :: FutureOutput . resolve_type_alias ( self . db , self . resolver . krate ( ) ) ?;
445
+ let future_projection = crate :: AliasTy :: Projection ( crate :: ProjectionTy {
446
+ associated_ty_id : to_assoc_type_id ( future_output) ,
447
+ substitution : Substitution :: from1 ( Interner , ret_param_future. clone ( ) ) ,
448
+ } ) ;
449
+ self . table . register_obligation (
450
+ crate :: AliasEq { alias : future_projection, ty : ret_param_future_output. clone ( ) }
451
+ . cast ( Interner ) ,
452
+ ) ;
453
+
454
+ Some ( FnSubst ( Substitution :: from_iter (
455
+ Interner ,
456
+ input_tys. iter ( Interner ) . map ( |t| t. cast ( Interner ) ) . chain ( Some ( GenericArg :: new (
457
+ Interner ,
458
+ chalk_ir:: GenericArgData :: Ty ( ret_param_future) ,
459
+ ) ) ) ,
460
+ ) ) )
461
+ }
462
+
427
463
fn fn_trait_kind_from_trait_id ( & self , trait_id : hir_def:: TraitId ) -> Option < FnTrait > {
428
464
FnTrait :: from_lang_item ( self . db . lang_attr ( trait_id. into ( ) ) ?)
429
465
}
0 commit comments