Skip to content

Commit

Permalink
Fix memory leak when requiring or importing modules that get GC'd lat…
Browse files Browse the repository at this point in the history
…er (#12997)
  • Loading branch information
Jarred-Sumner authored Aug 1, 2024
1 parent e585f90 commit 59c5c0f
Show file tree
Hide file tree
Showing 14 changed files with 366 additions and 82 deletions.
69 changes: 55 additions & 14 deletions src/bun.js/bindings/CommonJSModuleRecord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
*/

#include "headers.h"

#include "JavaScriptCore/Synchronousness.h"
#include "JavaScriptCore/JSCast.h"
#include <JavaScriptCore/JSMapInlines.h>
#include "root.h"
Expand Down Expand Up @@ -72,6 +74,7 @@
#include "PathInlines.h"
#include "wtf/NakedPtr.h"
#include "wtf/URL.h"
#include "wtf/text/StringImpl.h"

extern "C" bool Bun__isBunMain(JSC::JSGlobalObject* global, const BunString*);

Expand Down Expand Up @@ -152,7 +155,6 @@ static bool evaluateCommonJSModuleOnce(JSC::VM& vm, Zig::GlobalObject* globalObj
JSValue fnValue = JSC::evaluate(globalObject, code, jsUndefined(), exception);

if (UNLIKELY(exception.get() || fnValue.isEmpty())) {

return false;
}

Expand All @@ -171,6 +173,13 @@ static bool evaluateCommonJSModuleOnce(JSC::VM& vm, Zig::GlobalObject* globalObj
args.append(Zig::ImportMetaObject::create(globalObject, filename));
}

// Clear the source code as early as possible.
code = {};

// Call the CommonJS module wrapper function.
//
// fn(exports, require, module, __filename, __dirname) { /* code */ }(exports, require, module, __filename, __dirname)
//
JSC::call(globalObject, fn, callData, moduleObject, args, exception);

return exception.get() == nullptr;
Expand Down Expand Up @@ -352,9 +361,15 @@ JSC_DEFINE_CUSTOM_GETTER(getterParent, (JSC::JSGlobalObject * globalObject, JSC:
if (UNLIKELY(!thisObject)) {
return JSValue::encode(jsUndefined());
}
auto v = thisObject->m_parent.get();
if (v)
return JSValue::encode(thisObject->m_parent.get());

if (thisObject->m_overridenParent) {
return JSValue::encode(thisObject->m_overridenParent.get());
}

if (thisObject->m_parent) {
auto* parent = thisObject->m_parent.get();
return JSValue::encode(parent);
}

// initialize parent by checking if it is the main module. we do this lazily because most people
// dont need `module.parent` and creating commonjs module records is done a ton.
Expand All @@ -363,12 +378,11 @@ JSC_DEFINE_CUSTOM_GETTER(getterParent, (JSC::JSGlobalObject * globalObject, JSC:
auto id = idValue->value(globalObject);
auto idStr = Bun::toString(id);
if (Bun__isBunMain(globalObject, &idStr)) {
thisObject->m_parent.set(globalObject->vm(), thisObject, jsNull());
thisObject->m_overridenParent.set(globalObject->vm(), thisObject, jsNull());
return JSValue::encode(jsNull());
}
}

thisObject->m_parent.set(globalObject->vm(), thisObject, jsUndefined());
return JSValue::encode(jsUndefined());
}

Expand Down Expand Up @@ -454,7 +468,15 @@ JSC_DEFINE_CUSTOM_SETTER(setterParent,
if (!thisObject)
return false;

thisObject->m_parent.set(globalObject->vm(), thisObject, JSValue::decode(value));
JSValue decodedValue = JSValue::decode(value);

if (auto* parent = jsDynamicCast<JSCommonJSModule*>(decodedValue)) {
thisObject->m_parent = parent;
thisObject->m_overridenParent.clear();
} else {
thisObject->m_parent = {};
thisObject->m_overridenParent.set(globalObject->vm(), thisObject, JSValue::decode(value));
}

return true;
}
Expand Down Expand Up @@ -640,7 +662,7 @@ JSC_DEFINE_HOST_FUNCTION(jsFunctionCreateCommonJSModule, (JSGlobalObject * globa
{
RELEASE_ASSERT(callframe->argumentCount() == 4);

auto id = callframe->uncheckedArgument(0).toWTFString(globalObject);
auto id = callframe->uncheckedArgument(0).toString(globalObject);
JSValue object = callframe->uncheckedArgument(1);
JSValue hasEvaluated = callframe->uncheckedArgument(2);
ASSERT(hasEvaluated.isBoolean());
Expand All @@ -651,15 +673,14 @@ JSC_DEFINE_HOST_FUNCTION(jsFunctionCreateCommonJSModule, (JSGlobalObject * globa

JSCommonJSModule* JSCommonJSModule::create(
Zig::GlobalObject* globalObject,
const WTF::String& key,
JSC::JSString* requireMapKey,
JSValue exportsObject,
bool hasEvaluated,
JSValue parent)
{
auto& vm = globalObject->vm();
JSString* requireMapKey = JSC::jsStringWithCache(vm, key);

auto index = key.reverseFind(PLATFORM_SEP, key.length());
auto key = requireMapKey->value(globalObject);
auto index = key->reverseFind(PLATFORM_SEP, key->length());

JSString* dirname;
if (index != WTF::notFound) {
Expand All @@ -679,11 +700,31 @@ JSCommonJSModule* JSCommonJSModule::create(
exportsObject,
0);
out->hasEvaluated = hasEvaluated;
out->m_parent.set(vm, out, parent);
if (parent && parent.isCell()) {
if (auto* parentModule = jsDynamicCast<JSCommonJSModule*>(parent)) {
out->m_parent = JSC::Weak<JSCommonJSModule>(parentModule);
} else {
out->m_overridenParent.set(vm, out, parent);
}
} else if (parent) {
out->m_overridenParent.set(vm, out, parent);
}

return out;
}

JSCommonJSModule* JSCommonJSModule::create(
Zig::GlobalObject* globalObject,
const WTF::String& key,
JSValue exportsObject,
bool hasEvaluated,
JSValue parent)
{
auto& vm = globalObject->vm();
auto* requireMapKey = JSC::jsStringWithCache(vm, key);
return JSCommonJSModule::create(globalObject, requireMapKey, exportsObject, hasEvaluated, parent);
}

void JSCommonJSModule::destroy(JSC::JSCell* cell)
{
static_cast<JSCommonJSModule*>(cell)->JSCommonJSModule::~JSCommonJSModule();
Expand Down Expand Up @@ -908,6 +949,7 @@ void JSCommonJSModule::visitChildrenImpl(JSCell* cell, Visitor& visitor)
visitor.append(thisObject->m_filename);
visitor.append(thisObject->m_dirname);
visitor.append(thisObject->m_paths);
visitor.append(thisObject->m_overridenParent);
}

DEFINE_VISIT_CHILDREN(JSCommonJSModule);
Expand Down Expand Up @@ -1014,7 +1056,6 @@ bool JSCommonJSModule::evaluate(
this->sourceCode = JSC::SourceCode(WTFMove(sourceProvider));

WTF::NakedPtr<JSC::Exception> exception;

evaluateCommonJSModuleOnce(vm, globalObject, this, this->m_dirname.get(), this->m_filename.get(), exception);

if (exception.get()) {
Expand Down
20 changes: 19 additions & 1 deletion src/bun.js/bindings/CommonJSModuleRecord.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "JavaScriptCore/JSGlobalObject.h"
#include "JavaScriptCore/JSString.h"
#include "root.h"
#include "headers-handwritten.h"
#include "wtf/NakedPtr.h"
Expand Down Expand Up @@ -35,7 +36,19 @@ class JSCommonJSModule final : public JSC::JSDestructibleObject {
mutable JSC::WriteBarrier<Unknown> m_filename;
mutable JSC::WriteBarrier<JSString> m_dirname;
mutable JSC::WriteBarrier<Unknown> m_paths;
mutable JSC::WriteBarrier<Unknown> m_parent;

// Visited by the GC. When the module is assigned a non-JSCommonJSModule
// parent, it is assigned to this field.
//
// module.parent = parent;
//
mutable JSC::WriteBarrier<Unknown> m_overridenParent;

// Not visited by the GC.
// When the module is assigned a JSCommonJSModule parent, it is assigned to this field.
// This is the normal state.
JSC::Weak<JSCommonJSModule> m_parent {};

bool ignoreESModuleAnnotation { false };
JSC::SourceCode sourceCode = JSC::SourceCode();

Expand Down Expand Up @@ -70,6 +83,11 @@ class JSCommonJSModule final : public JSC::JSDestructibleObject {
const WTF::String& key,
JSValue exportsObject, bool hasEvaluated, JSValue parent);

static JSCommonJSModule* create(
Zig::GlobalObject* globalObject,
JSC::JSString* key,
JSValue exportsObject, bool hasEvaluated, JSValue parent);

static JSCommonJSModule* create(
Zig::GlobalObject* globalObject,
const WTF::String& key,
Expand Down
57 changes: 22 additions & 35 deletions src/bun.js/bindings/ModuleLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ using namespace JSC;
using namespace Zig;
using namespace WebCore;

class ResolvedSourceCodeHolder {
public:
ResolvedSourceCodeHolder(ErrorableResolvedSource* res_)
: res(res_)
{
}

~ResolvedSourceCodeHolder()
{
if (res->success && res->result.value.source_code.tag == BunStringTag::WTFStringImpl && res->result.value.needsDeref) {
res->result.value.needsDeref = false;
res->result.value.source_code.impl.wtf->deref();
}
}

ErrorableResolvedSource* res;
};

extern "C" BunLoaderType Bun__getDefaultLoader(JSC::JSGlobalObject*, BunString* specifier);

static JSC::JSInternalPromise* rejectedInternalPromise(JSC::JSGlobalObject* globalObject, JSC::JSValue value)
Expand Down Expand Up @@ -289,13 +307,7 @@ static JSValue handleVirtualModuleResult(
auto onLoadResult = handleOnLoadResult(globalObject, virtualModuleResult, specifier, wasModuleMock);
JSC::VM& vm = globalObject->vm();
auto scope = DECLARE_THROW_SCOPE(vm);
WTF::String sourceCodeStringForDeref;
const auto getSourceCodeStringForDeref = [&]() {
if (res->success && res->result.value.needsDeref && res->result.value.source_code.tag == BunStringTag::WTFStringImpl) {
res->result.value.needsDeref = false;
sourceCodeStringForDeref = String(res->result.value.source_code.impl.wtf);
}
};
ResolvedSourceCodeHolder sourceCodeHolder(res);

const auto reject = [&](JSC::JSValue exception) -> JSValue {
if constexpr (allowPromise) {
Expand Down Expand Up @@ -342,7 +354,6 @@ static JSValue handleVirtualModuleResult(
if (!res->success) {
return reject(JSValue::decode(reinterpret_cast<EncodedJSValue>(res->result.err.ptr)));
}
getSourceCodeStringForDeref();

auto provider = Zig::SourceProvider::create(globalObject, res->result.value);
return resolve(JSC::JSSourceCode::create(vm, JSC::SourceCode(provider)));
Expand Down Expand Up @@ -396,13 +407,7 @@ extern "C" void Bun__onFulfillAsyncModule(
BunString* specifier,
BunString* referrer)
{
WTF::String sourceCodeStringForDeref;
const auto getSourceCodeStringForDeref = [&]() {
if (res->result.value.needsDeref && res->result.value.source_code.tag == BunStringTag::WTFStringImpl) {
res->result.value.needsDeref = false;
sourceCodeStringForDeref = String(res->result.value.source_code.impl.wtf);
}
};
ResolvedSourceCodeHolder sourceCodeHolder(res);
auto& vm = globalObject->vm();
auto scope = DECLARE_THROW_SCOPE(vm);
JSC::JSInternalPromise* promise = jsCast<JSC::JSInternalPromise*>(JSC::JSValue::decode(encodedPromiseValue));
Expand All @@ -414,7 +419,6 @@ extern "C" void Bun__onFulfillAsyncModule(
return promise->reject(globalObject, exception);
}

getSourceCodeStringForDeref();
auto specifierValue = Bun::toJS(globalObject, *specifier);

if (auto entry = globalObject->esmRegistryMap()->get(globalObject, specifierValue)) {
Expand Down Expand Up @@ -467,13 +471,7 @@ JSValue fetchCommonJSModule(
memset(&resValue, 0, sizeof(ErrorableResolvedSource));

ErrorableResolvedSource* res = &resValue;
WTF::String sourceCodeStringForDeref;
const auto getSourceCodeStringForDeref = [&]() {
if (res->success && res->result.value.needsDeref && res->result.value.source_code.tag == BunStringTag::WTFStringImpl) {
res->result.value.needsDeref = false;
sourceCodeStringForDeref = String(res->result.value.source_code.impl.wtf);
}
};
ResolvedSourceCodeHolder sourceCodeHolder(res);
auto& builtinNames = WebCore::clientData(vm)->builtinNames();

bool wasModuleMock = false;
Expand Down Expand Up @@ -600,8 +598,6 @@ JSValue fetchCommonJSModule(
}

Bun__transpileFile(bunVM, globalObject, specifier, referrer, typeAttribute, res, false);
getSourceCodeStringForDeref();

if (res->success && res->result.value.isCommonJSModule) {
target->evaluate(globalObject, specifier->toWTFString(BunString::ZeroCopy), res->result.value);
RETURN_IF_EXCEPTION(scope, {});
Expand Down Expand Up @@ -669,6 +665,7 @@ static JSValue fetchESMSourceCode(
void* bunVM = globalObject->bunVM();
auto& vm = globalObject->vm();
auto scope = DECLARE_THROW_SCOPE(vm);
ResolvedSourceCodeHolder sourceCodeHolder(res);

const auto reject = [&](JSC::JSValue exception) -> JSValue {
if constexpr (allowPromise) {
Expand Down Expand Up @@ -755,23 +752,13 @@ static JSValue fetchESMSourceCode(
}
}

WTF::String sourceCodeStringForDeref;
const auto getSourceCodeStringForDeref = [&]() {
if (res->success && res->result.value.needsDeref && res->result.value.source_code.tag == BunStringTag::WTFStringImpl) {
res->result.value.needsDeref = false;
sourceCodeStringForDeref = String(res->result.value.source_code.impl.wtf);
}
};

if constexpr (allowPromise) {
auto* pendingCtx = Bun__transpileFile(bunVM, globalObject, specifier, referrer, typeAttribute, res, true);
getSourceCodeStringForDeref();
if (pendingCtx) {
return pendingCtx;
}
} else {
Bun__transpileFile(bunVM, globalObject, specifier, referrer, typeAttribute, res, false);
getSourceCodeStringForDeref();
}

if (res->success && res->result.value.isCommonJSModule) {
Expand Down
Loading

0 comments on commit 59c5c0f

Please sign in to comment.