> 文档中心 > ThreadLocal精进篇:子线程类InheritableThreadLocal

ThreadLocal精进篇:子线程类InheritableThreadLocal

 背景

ThreadLocal可以保证在当前运行线程中的变量不被其他并发下的线程共享。

但是如果在代码中需要使用多线程呢?

那么,ThreadLocal 中保存的数据该如何安全地传递给相关的子线程呢?

InheritableThreadLocal给我们提供了一丝可能。

InheritableThreadLocal

InheritableThreadLocal源码

先简单瞄下源码

package java.lang;import java.lang.ref.*;public class InheritableThreadLocal extends ThreadLocal {     protected T childValue(T parentValue) { return parentValue;    } ThreadLocalMap getMap(Thread t) {return t.inheritableThreadLocals;    } void createMap(Thread t, T firstValue) { t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);    }}

可以很清楚地看到,InheritableThreadLocal 是 ThreadLocal 的子类,它只重写了 childValue、getMap 和 createMap 这三个方法。

 

测试代码

一切以代码为主,我们先上一段代码

package com.chuangyue;

import com.google.common.util.concurrent.ThreadFactoryBuilder;

import java.util.concurrent.*;

/**
 * 可继承线程本地测试 (inheritable thread-local test).
 *
 * <p>Demonstrates the difference between {@link ThreadLocal} and
 * {@link InheritableThreadLocal}. A value is set in each on the main thread and
 * then read from a thread-pool worker thread.
 */
public class InheritableThreadLocalTest {

    /**
     * Sets one value in a plain ThreadLocal and one in an InheritableThreadLocal
     * on the main thread, then reads both from a worker thread via {@link #test01}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ThreadLocal<String> threadLocal = new ThreadLocal<>();
        InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
        threadLocal.set(" I'm threadLocal ");
        inheritableThreadLocal.set("I'm inheritableThreadLocal");
        test01(threadLocal, inheritableThreadLocal);
    }

    /**
     * Submits one task to a single-thread pool. The task prints both
     * thread-locals' values as seen from the worker thread.
     *
     * <p>Why the inheritable value is visible: per the ThreadPoolExecutor Javadoc,
     * core threads are created on demand when tasks arrive. The worker is
     * therefore constructed inside {@code execute}, on the calling thread, and
     * copies the caller's inheritable values at that moment. A worker that already
     * exists (a reused pool thread) would NOT pick up later changes made by the caller.
     *
     * @param threadLocal a plain thread-local set by the caller (the worker sees {@code null})
     * @param inheritableThreadLocal an inheritable thread-local set by the caller
     */
    public static void test01(ThreadLocal<?> threadLocal, InheritableThreadLocal<?> inheritableThreadLocal) {
        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
                .setNameFormat("demo-pool-%d")
                .build();
        ExecutorService singleThreadPool = new ThreadPoolExecutor(1, 1,
                0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(1024),
                namedThreadFactory,
                new ThreadPoolExecutor.AbortPolicy());
        try {
            singleThreadPool.execute(() -> {
                System.out.println("========= begin =========");
                System.out.println(Thread.currentThread().getName() + " ThreadLocal: " + threadLocal.get());
                System.out.println(Thread.currentThread().getName() + " InheritableThreadLocal: " + inheritableThreadLocal.get());
                System.out.println("========= end =========");
            });
        } finally {
            // Shut down even if submission fails, so no pool thread keeps the JVM alive.
            singleThreadPool.shutdown();
        }
    }
}

这段测试代码只是创建了 ThreadLocal 和 InheritableThreadLocal 两个对象,并在主线程中分别对其赋值。

然后再在线程中打印对应的对象结果。

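输出结果如下:

========= begin =========
demo-pool-0 ThreadLocal: null
demo-pool-0 InheritableThreadLocal: I'm inheritableThreadLocal
========= end =========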

 按照正常的理解,应该是两个值都要输出。

但是在这里,我们可以清楚的看到,ThreadLocal输出值为null

这是为什么呢?

 

还是老规矩,我们看下源码

ThreadLocal源码

package java.lang;import jdk.internal.misc.TerminatingThreadLocal;import java.lang.ref.*;import java.util.Objects;import java.util.concurrent.atomic.AtomicInteger;import java.util.function.Supplier;public class ThreadLocal {private final int threadLocalHashCode = nextHashCode();    /**     * The next hash code to be given out. Updated atomically. Starts at     * zero.     */    private static AtomicInteger nextHashCode = new AtomicInteger();    private static final int HASH_INCREMENT = 0x61c88647;    /**     * Returns the next hash code.     */    private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT);    }    protected T initialValue() { return null;    }     public static  ThreadLocal withInitial(Supplier supplier) { return new SuppliedThreadLocal(supplier);    }    /**     * Creates a thread local variable.     * @see #withInitial(java.util.function.Supplier)     */    public ThreadLocal() {    }     public T get() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) {     ThreadLocalMap.Entry e = map.getEntry(this);     if (e != null) {  @SuppressWarnings("unchecked")  T result = (T)e.value;  return result;     } } return setInitialValue();    }boolean isPresent() { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); return map != null && map.getEntry(this) != null;    } private T setInitialValue() { T value = initialValue(); Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) {     map.set(this, value); } else {     createMap(t, value); } if (this instanceof TerminatingThreadLocal) {     TerminatingThreadLocal.register((TerminatingThreadLocal) this); } return value;    }public void set(T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) {     map.set(this, value); } else {     createMap(t, value); }    } public void remove() {  ThreadLocalMap m = getMap(Thread.currentThread());  if (m != null) {      
m.remove(this);  }     }    /**     * Get the map associated with a ThreadLocal. Overridden in     * InheritableThreadLocal.     *     * @param  t the current thread     * @return the map     */    ThreadLocalMap getMap(Thread t) { return t.threadLocals;    }    /**     * Create the map associated with a ThreadLocal. Overridden in     * InheritableThreadLocal.     *     * @param t the current thread     * @param firstValue value for the initial entry of the map     */    void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue);    }    /**     * Factory method to create map of inherited thread locals.     * Designed to be called only from Thread constructor.     *     * @param  parentMap the map associated with parent thread     * @return a map containing the parent's inheritable bindings     */    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) { return new ThreadLocalMap(parentMap);    }    /**     * Method childValue is visibly defined in subclass     * InheritableThreadLocal, but is internally defined here for the     * sake of providing createInheritedMap factory method without     * needing to subclass the map class in InheritableThreadLocal.     * This technique is preferable to the alternative of embedding     * instanceof tests in methods.     */    T childValue(T parentValue) { throw new UnsupportedOperationException();    }    /**     * An extension of ThreadLocal that obtains its initial value from     * the specified {@code Supplier}.     */    static final class SuppliedThreadLocal extends ThreadLocal { private final Supplier supplier; SuppliedThreadLocal(Supplier supplier) {     this.supplier = Objects.requireNonNull(supplier); } @Override protected T initialValue() {     return supplier.get(); }    }    /**     * ThreadLocalMap is a customized hash map suitable only for     * maintaining thread local values. No operations are exported     * outside of the ThreadLocal class. 
The class is package private to     * allow declaration of fields in class Thread.  To help deal with     * very large and long-lived usages, the hash table entries use     * WeakReferences for keys. However, since reference queues are not     * used, stale entries are guaranteed to be removed only when     * the table starts running out of space.     */    static class ThreadLocalMap { /**  * The entries in this hash map extend WeakReference, using  * its main ref field as the key (which is always a  * ThreadLocal object).  Note that null keys (i.e. entry.get()  * == null) mean that the key is no longer referenced, so the  * entry can be expunged from table.  Such entries are referred to  * as "stale entries" in the code that follows.  */ static class Entry extends WeakReference<ThreadLocal> {     /** The value associated with this ThreadLocal. */     Object value;     Entry(ThreadLocal k, Object v) {  super(k);  value = v;     } } /**  * The initial capacity -- MUST be a power of two.  */ private static final int INITIAL_CAPACITY = 16; /**  * The table, resized as necessary.  * table.length MUST always be a power of two.  */ private Entry[] table; /**  * The number of entries in the table.  */ private int size = 0; /**  * The next size value at which to resize.  */ private int threshold; // Default to 0 /**  * Set the resize threshold to maintain at worst a 2/3 load factor.  */ private void setThreshold(int len) {     threshold = len * 2 / 3; } /**  * Increment i modulo len.  */ private static int nextIndex(int i, int len) {     return ((i + 1 = 0) ? i - 1 : len - 1); } /**  * Construct a new map initially containing (firstKey, firstValue).  * ThreadLocalMaps are constructed lazily, so we only create  * one when we have at least one entry to put in it.  
*/ ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {     table = new Entry[INITIAL_CAPACITY];     int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);     table[i] = new Entry(firstKey, firstValue);     size = 1;     setThreshold(INITIAL_CAPACITY); } /**  * Construct a new map including all Inheritable ThreadLocals  * from given parent map. Called only by createInheritedMap.  *  * @param parentMap the map associated with parent thread.  */ private ThreadLocalMap(ThreadLocalMap parentMap) {     Entry[] parentTable = parentMap.table;     int len = parentTable.length;     setThreshold(len);     table = new Entry[len];     for (Entry e : parentTable) {  if (e != null) {      @SuppressWarnings("unchecked")      ThreadLocal key = (ThreadLocal) e.get();      if (key != null) {   Object value = key.childValue(e.value);   Entry c = new Entry(key, value);   int h = key.threadLocalHashCode & (len - 1);   while (table[h] != null)h = nextIndex(h, len);   table[h] = c;   size++;      }  }     } } /**  * Get the entry associated with key.  This method  * itself handles only the fast path: a direct hit of existing  * key. It otherwise relays to getEntryAfterMiss.  This is  * designed to maximize performance for direct hits, in part  * by making this method readily inlinable.  *  * @param  key the thread local object  * @return the entry associated with key, or null if no such  */ private Entry getEntry(ThreadLocal key) {     int i = key.threadLocalHashCode & (table.length - 1);     Entry e = table[i];     if (e != null && e.get() == key)  return e;     else  return getEntryAfterMiss(key, i, e); } /**  * Version of getEntry method for use when key is not found in  * its direct hash slot.  
*  * @param  key the thread local object  * @param  i the table index for key's hash code  * @param  e the entry at table[i]  * @return the entry associated with key, or null if no such  */ private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {     Entry[] tab = table;     int len = tab.length;     while (e != null) {  ThreadLocal k = e.get();  if (k == key)      return e;  if (k == null)      expungeStaleEntry(i);  else      i = nextIndex(i, len);  e = tab[i];     }     return null; } /**  * Set the value associated with key.  *  * @param key the thread local object  * @param value the value to be set  */ private void set(ThreadLocal key, Object value) {     // We don't use a fast path as with get() because it is at     // least as common to use set() to create new entries as     // it is to replace existing ones, in which case, a fast     // path would fail more often than not.     Entry[] tab = table;     int len = tab.length;     int i = key.threadLocalHashCode & (len-1);     for (Entry e = tab[i];   e != null;   e = tab[i = nextIndex(i, len)]) {  ThreadLocal k = e.get();  if (k == key) {      e.value = value;      return;  }  if (k == null) {      replaceStaleEntry(key, value, i);      return;  }     }     tab[i] = new Entry(key, value);     int sz = ++size;     if (!cleanSomeSlots(i, sz) && sz >= threshold)  rehash(); } /**  * Remove the entry for key.  */ private void remove(ThreadLocal key) {     Entry[] tab = table;     int len = tab.length;     int i = key.threadLocalHashCode & (len-1);     for (Entry e = tab[i];   e != null;   e = tab[i = nextIndex(i, len)]) {  if (e.get() == key) {      e.clear();      expungeStaleEntry(i);      return;  }     } } /**  * Replace a stale entry encountered during a set operation  * with an entry for the specified key.  The value passed in  * the value parameter is stored in the entry, whether or not  * an entry already exists for the specified key.  
*  * As a side effect, this method expunges all stale entries in the  * "run" containing the stale entry.  (A run is a sequence of entries  * between two null slots.)  *  * @param  key the key  * @param  value the value to be associated with key  * @param  staleSlot index of the first stale entry encountered while  *  searching for key.  */ private void replaceStaleEntry(ThreadLocal key, Object value,    int staleSlot) {     Entry[] tab = table;     int len = tab.length;     Entry e;     // Back up to check for prior stale entry in current run.     // We clean out whole runs at a time to avoid continual     // incremental rehashing due to garbage collector freeing     // up refs in bunches (i.e., whenever the collector runs).     int slotToExpunge = staleSlot;     for (int i = prevIndex(staleSlot, len);   (e = tab[i]) != null;   i = prevIndex(i, len))  if (e.get() == null)      slotToExpunge = i;     // Find either the key or trailing null slot of run, whichever     // occurs first     for (int i = nextIndex(staleSlot, len);   (e = tab[i]) != null;   i = nextIndex(i, len)) {  ThreadLocal k = e.get();  // If we find key, then we need to swap it  // with the stale entry to maintain hash table order.  // The newly stale slot, or any other stale slot  // encountered above it, can then be sent to expungeStaleEntry  // to remove or rehash all of the other entries in run.  if (k == key) {      e.value = value;      tab[i] = tab[staleSlot];      tab[staleSlot] = e;      // Start expunge at preceding stale entry if it exists      if (slotToExpunge == staleSlot)   slotToExpunge = i;      cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);      return;  }  // If we didn't find stale entry on backward scan, the  // first stale entry seen while scanning for key is the  // first still present in the run.  
if (k == null && slotToExpunge == staleSlot)      slotToExpunge = i;     }     // If key not found, put new entry in stale slot     tab[staleSlot].value = null;     tab[staleSlot] = new Entry(key, value);     // If there are any other stale entries in run, expunge them     if (slotToExpunge != staleSlot)  cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); } /**  * Expunge a stale entry by rehashing any possibly colliding entries  * lying between staleSlot and the next null slot.  This also expunges  * any other stale entries encountered before the trailing null.  See  * Knuth, Section 6.4  *  * @param staleSlot index of slot known to have null key  * @return the index of the next null slot after staleSlot  * (all between staleSlot and this slot will have been checked  * for expunging).  */ private int expungeStaleEntry(int staleSlot) {     Entry[] tab = table;     int len = tab.length;     // expunge entry at staleSlot     tab[staleSlot].value = null;     tab[staleSlot] = null;     size--;     // Rehash until we encounter null     Entry e;     int i;     for (i = nextIndex(staleSlot, len);   (e = tab[i]) != null;   i = nextIndex(i, len)) {  ThreadLocal k = e.get();  if (k == null) {      e.value = null;      tab[i] = null;      size--;  } else {      int h = k.threadLocalHashCode & (len - 1);      if (h != i) {   tab[i] = null;   // Unlike Knuth 6.4 Algorithm R, we must scan until   // null because multiple entries could have been stale.   while (tab[h] != null)h = nextIndex(h, len);   tab[h] = e;      }  }     }     return i; } /**  * Heuristically scan some cells looking for stale entries.  * This is invoked when either a new element is added, or  * another stale one has been expunged. It performs a  * logarithmic number of scans, as a balance between no  * scanning (fast but retains garbage) and a number of scans  * proportional to number of elements, that would find all  * garbage but would cause some insertions to take O(n) time.  
*  * @param i a position known NOT to hold a stale entry. The  * scan starts at the element after i.  *  * @param n scan control: {@code log2(n)} cells are scanned,  * unless a stale entry is found, in which case  * {@code log2(table.length)-1} additional cells are scanned.  * When called from insertions, this parameter is the number  * of elements, but when from replaceStaleEntry, it is the  * table length. (Note: all this could be changed to be either  * more or less aggressive by weighting n instead of just  * using straight log n. But this version is simple, fast, and  * seems to work well.)  *  * @return true if any stale entries have been removed.  */ private boolean cleanSomeSlots(int i, int n) {     boolean removed = false;     Entry[] tab = table;     int len = tab.length;     do {  i = nextIndex(i, len);  Entry e = tab[i];  if (e != null && e.get() == null) {      n = len;      removed = true;      i = expungeStaleEntry(i);  }     } while ( (n >>>= 1) != 0);     return removed; } /**  * Re-pack and/or re-size the table. First scan the entire  * table removing stale entries. If this doesn't sufficiently  * shrink the size of the table, double the table size.  */ private void rehash() {     expungeStaleEntries();     // Use lower threshold for doubling to avoid hysteresis     if (size >= threshold - threshold / 4)  resize(); } /**  * Double the capacity of the table.  */ private void resize() {     Entry[] oldTab = table;     int oldLen = oldTab.length;     int newLen = oldLen * 2;     Entry[] newTab = new Entry[newLen];     int count = 0;     for (Entry e : oldTab) {  if (e != null) {      ThreadLocal k = e.get();      if (k == null) {   e.value = null; // Help the GC      } else {   int h = k.threadLocalHashCode & (newLen - 1);   while (newTab[h] != null)h = nextIndex(h, newLen);   newTab[h] = e;   count++;      }  }     }     setThreshold(newLen);     size = count;     table = newTab; } /**  * Expunge all stale entries in the table.  
*/ private void expungeStaleEntries() {     Entry[] tab = table;     int len = tab.length;     for (int j = 0; j < len; j++) {  Entry e = tab[j];  if (e != null && e.get() == null)      expungeStaleEntry(j);     } }    }}

我们可以很清楚地看到,ThreadLocal 中同样存在以下三个方法:getMap、createMap 和 childValue。

    /**
     * Returns the map holding this ThreadLocal's values for thread {@code t}:
     * the plain {@code threadLocals} field. InheritableThreadLocal overrides this
     * to return {@code inheritableThreadLocals} instead.
     */
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    /**
     * Creates thread {@code t}'s {@code threadLocals} map with
     * ({@code this}, {@code firstValue}) as its first entry.
     */
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    /**
     * Base-class placeholder that always throws. Only InheritableThreadLocal
     * supplies a real implementation. It is declared here so that the inheriting
     * ThreadLocalMap constructor can call it on any key without instanceof tests.
     */
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }

我们再看下 InheritableThreadLocal 源码中对重写的这三个方法的代码注释:

    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child thread
     * is created. This method is called from within the parent thread before
     * the child is started.
     *
     * <p>This method simply returns its input argument, so the child shares the
     * same object reference as the parent. Override it if different behavior
     * (e.g. a defensive copy) is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

/**
 * Gets the map associated with this ThreadLocal. For InheritableThreadLocal this
 * is the thread's {@code inheritableThreadLocals} field rather than
 * {@code threadLocals}.
 *
 * @param t the current thread
 * @return the map
 */
ThreadLocalMap getMap(Thread t) {
    return t.inheritableThreadLocals;
}
    /**
     * Creates the map associated with this ThreadLocal and stores it in the
     * thread's {@code inheritableThreadLocals} field.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the map
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }

我们直接跟踪下代码

可以看到,在子线程中执行 get() 时,getMap(t) 取到的当前线程的 threadLocals 直接就是 null。

 

为什么ThreadLocal值为空

这个问题主要出在Thread上。

我们在Thread代码中找到了这两行

public class Thread implements Runnable {
    ...
    /*
     * Map backing plain ThreadLocal values for this thread (ThreadLocal.getMap
     * returns it). It starts as null and is created lazily by ThreadLocal.createMap.
     */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * Map backing InheritableThreadLocal values for this thread. Besides being
     * created lazily by InheritableThreadLocal.createMap, it can be filled in by
     * the Thread constructor with a copy of the parent thread's map.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
...

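每个线程(不只是主线程)都有一个自己的 ThreadLocalMap,即 Thread 的 threadLocals 字段。子线程调用 get 方法取值时,访问的是它自己的 ThreadLocalMap。这个 Map 和主线程的 Map 是两个不同的对象,所以肯定拿不到主线程设置的值。更具体地说,子线程的 threadLocals 初始值是 null,get() 因此会走到 setInitialValue(),而默认的 initialValue() 返回 null,所以输出就是 null 了。这个没毛病。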

那么,inheritableThreadLocals初始化也为null,为什么它就有值呢?

Thread初探

 

Thread自己怎么玩的

在研究这个之前,我们先看下Thread自己是怎么玩的(初始化)

   //构造函数    public Thread(ThreadGroup group, String name) { this(group, null, name, 0);    }    public Thread(Runnable target, String name) { this(null, target, name, 0);    }     public Thread(ThreadGroup group, Runnable target, String name,    long stackSize) { this(group, target, name, stackSize, null, true);    }...

最后它们都甩锅给了它:private Thread

/**
 * Private constructor that the public constructors delegate to. It runs on the
 * creating thread, so {@code currentThread()} below is the PARENT thread.
 *
 * @param g the thread group; if null, taken from the security manager if it has
 *        one, otherwise from the parent thread
 * @param target the Runnable stored as this thread's target
 * @param name the name of the new thread (must not be null)
 * @param stackSize the requested stack size, stashed for the VM
 * @param acc the access control context to inherit; if null,
 *        AccessController.getContext() is used
 * @param inheritThreadLocals if true, copy the parent's inheritable thread-locals
 * @throws NullPointerException if {@code name} is null
 */
private Thread(ThreadGroup g, Runnable target, String name,
               long stackSize, AccessControlContext acc,
               boolean inheritThreadLocals) {
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;

    // The parent is the thread executing this constructor, not the new thread.
    Thread parent = currentThread();
    SecurityManager security = System.getSecurityManager();
    if (g == null) {
        /* Determine if it's an applet or not */

        /* If there is a security manager, ask the security manager what to do. */
        if (security != null) {
            g = security.getThreadGroup();
        }

        /* If the security manager doesn't have a strong opinion on the matter, use the parent thread group. */
        if (g == null) {
            g = parent.getThreadGroup();
        }
    }

    /* checkAccess regardless of whether or not threadgroup is
       explicitly passed in. */
    g.checkAccess();

    /*
     * Do we have the required permissions?
     */
    if (security != null) {
        if (isCCLOverridden(getClass())) {
            security.checkPermission(
                SecurityConstants.SUBCLASS_IMPLEMENTATION_PERMISSION);
        }
    }

    g.addUnstarted();

    // Group, daemon status, priority and context class loader are all taken
    // from the parent (creating) thread.
    this.group = g;
    this.daemon = parent.isDaemon();
    this.priority = parent.getPriority();
    if (security == null || isCCLOverridden(parent.getClass()))
        this.contextClassLoader = parent.getContextClassLoader();
    else
        this.contextClassLoader = parent.contextClassLoader;
    this.inheritedAccessControlContext =
            acc != null ? acc : AccessController.getContext();
    this.target = target;
    setPriority(priority);
    // The ONLY point where inheritable thread-locals are passed down. The child
    // receives a new map built from the parent's map at this moment (values go
    // through childValue). Later set() calls in either thread only touch that
    // thread's own map. Plain threadLocals are never copied.
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    this.tid = nextThreadID();
}
。。。

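我们可以看到,上面这些公有构造方法最终传给私有构造方法的 inheritThreadLocals 都是 true。因此,只要父线程(即执行构造方法的线程)的 inheritableThreadLocals 不为 null,新线程就会在构造时通过 createInheritedMap 得到一份拷贝。也就是说,InheritableThreadLocal 的值默认会传递给子线程。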

ThreadLocal.createInheritedMap

上面的构造方法中,正是通过 ThreadLocal.createInheritedMap(parent.inheritableThreadLocals) 这一行来创建 inheritedMap 的。

我们继续剖析createInheritedMap源码

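createInheritedMap 会新建一个与父线程 Map 容量相同的 ThreadLocalMap,再用 for 循环遍历父线程 Map 中的每个 Entry。key 已被回收(为 null)的过期 Entry 会被跳过;其余 Entry 则先调用 childValue 计算出子线程的值,再放入子线程的新 Map 中。由于默认的 childValue 直接返回父线程的值,这只是一次浅拷贝:父子线程引用的是同一个对象。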

     // Excerpt from java.lang.ThreadLocal and its nested ThreadLocalMap class.
    /**
     * Factory method to create a map of inherited thread locals.
     * Designed to be called only from the Thread constructor.
     *
     * @param parentMap the map associated with the parent thread
     * @return a map containing the parent's inheritable bindings
     */
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

    /**
     * Constructs a new ThreadLocalMap containing all inheritable ThreadLocals
     * from the given parent map. Called only by createInheritedMap.
     *
     * @param parentMap the map associated with the parent thread
     */
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        Entry[] parentTable = parentMap.table;
        int len = parentTable.length;
        setThreshold(len);
        // Same capacity as the parent's table, so every copied entry fits
        // without resizing.
        table = new Entry[len];

        for (Entry e : parentTable) {
            if (e != null) {
                // Keys are weakly referenced; e.get() is null for a stale entry.
                @SuppressWarnings("unchecked")
                ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                if (key != null) {
                    // childValue() is the hook InheritableThreadLocal uses to derive
                    // the child's value; its default returns the parent's value as-is
                    // (same object reference, i.e. a shallow copy).
                    Object value = key.childValue(e.value);
                    Entry c = new Entry(key, value);
                    int h = key.threadLocalHashCode & (len - 1);
                    // Linear probing: step to the next free slot on collision.
                    while (table[h] != null)
                        h = nextIndex(h, len);
                    table[h] = c;
                    size++;
                }
            }
        }
    }

 

总结

所以,Thread 类中包含 threadLocals 和 inheritableThreadLocals 两个变量。其中 inheritableThreadLocals 这个 ThreadLocal.ThreadLocalMap 会在子线程创建时自动从父线程复制一份,这样子线程调用 get() 时自然就能取到值了。需要注意的是,这个复制只在子线程构造时发生一次:之后父线程再修改值,已经创建好的子线程(包括线程池中被复用的线程)是感知不到的。